Skip to content
Snippets Groups Projects
Commit 40a69a18 authored by NolanB's avatar NolanB
Browse files

Commit for help fixing tests

parent f41589b5
No related branches found
No related tags found
No related merge requests found
Pipeline #79374 failed
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml,zstandard
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
......@@ -26,6 +26,7 @@ from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
SAMPLES_DIR = Path("tests/samples")
__yaml_cache = {}
......@@ -276,6 +277,11 @@ def fake_transcriptions_small():
return json.load(f)
@pytest.fixture
def model_file_dir():
return SAMPLES_DIR / "model_files"
@pytest.fixture
def fake_dummy_worker():
api_client = MockApiClient()
......
# -*- coding: utf-8 -*-
# import json
import os
from pathlib import Path
import pytest
import responses
from apistar.exceptions import ErrorResponse
from arkindex.mock import MockApiClient
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.training import TrainingMixin, create_archive
# import pytest
# import responses
# from responses import matchers
class TrainingWorker(BaseWorker, TrainingMixin):
def __init__(self):
super().setup_api_client()
pass
def test_create_archive_folder():
model_file_dir = Path("tests/samples/model_files")
def test_create_archive_folder(model_file_dir):
with create_archive(path=model_file_dir) as (
zst_archive_path,
hash,
......@@ -28,9 +26,7 @@ def test_create_archive_folder():
):
assert os.path.exists(zst_archive_path), "The archive was not created"
assert (
# hash == "7dd70931222ef0496ea75e5aee674043"
hash
== "c5aedde18a768757351068b840c8c8f9"
hash == "c5aedde18a768757351068b840c8c8f9"
), "Hash was not properly computed"
assert 300 < size < 700
......@@ -41,120 +37,114 @@ def test_create_model_version():
"""A new model version is returned"""
model_id = "fake_model_id"
model_version_id = "fake_model_version_id"
# Create a model archive and keep its hash and size.
model_files_dir = Path("tests/samples/model_files")
# model_file_path = model_files_dir / "model_file.pth"
training = TrainingWorker()
training.api_client = MockApiClient()
with create_archive(path=model_files_dir) as (
zst_archive_path,
hash,
size,
archive_hash,
):
model_version_details = {
"id": model_version_id,
"model_id": model_id,
"hash": hash,
"archive_hash": archive_hash,
"size": size,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
response_mock = training.api_client.add_response(
hash = "hash"
archive_hash = "archive_hash"
size = "30"
model_version_details = {
"id": model_version_id,
"model_id": model_id,
"hash": hash,
"archive_hash": archive_hash,
"size": size,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
training.api_client.add_response(
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={"hash": hash, "archive_hash": archive_hash, "size": size},
)
print(response_mock)
assert (
training.create_model_version(model_id, hash, size, archive_hash)
== response_mock
== model_version_details
)
# def test_retrieve_created_model_version(api_client, samples_dir):
# """There is an existing model version in Created mode, A 400 was raised.
# But the model is still returned in error content
# """
# model_id = "fake_model_id"
# model_version_id = "fake_model_version_id"
# # Create a model archive and keep its hash and size.
# model_file_path = samples_dir / "model_file.pth"
# training = TrainingWorker()
# training.api_client = MockApiClient()
# with create_archive(path=model_file_path) as (
# zst_archive_path,
# hash,
# size,
# archive_hash,
# ):
# model_version_details = {
# "id": model_version_id,
# "model_id": model_id,
# "hash": hash,
# "archive_hash": archive_hash,
# "size": size,
# "s3_url": "http://hehehe.com",
# "s3_put_url": "http://hehehe.com",
# }
# response_mock = training.api_client.add_response(
@pytest.mark.parametrize(
"process_exception, status_code",
[
# (None, 200)
(
ErrorResponse(
title="Mock error response", status_code=400, content="Bad gateway"
),
400,
),
(
ErrorResponse(
title="Mock error response", status_code=403, content="Bad gateway"
),
403,
),
],
)
def test_retrieve_created_model_version(process_exception, status_code):
"""There is an existing model version in Created mode, A 400 was raised.
But the model is still returned in error content
"""
model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
hash = "hash"
archive_hash = "archive_hash"
size = "30"
training.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={"hash": hash, "archive_hash": archive_hash, "size": size},
)
# )
if process_exception.status_code == status_code:
assert (
training.create_model_version(model_id, hash, size, archive_hash)
== process_exception
)
# # responses.add(
# # responses.POST,
# # f"http://testserver/api/v1/model/{model_id}/versions/",
# # status=400,
# # match=[
# # matchers.json_params_matcher(
# # {"hash": hash, "archive_hash": archive_hash, "size": size}
# # )
# # ],
# # json={"hash": model_version_details},
# # )
# assert (
# training.create_model_version(api_client, model_id, hash, size, archive_hash)
# == model_version_details
# )
# with pytest.raises(Exception):
# training.create_model_version(model_id, hash, size, archive_hash)
# def test_retrieve_available_model_version(api_client, samples_dir):
# def test_retrieve_available_model_version():
# """Raise error when there is an existing model version in Available mode"""
# model_id = "fake_model_id"
# # Create a model archive and keep its hash and size.
# model_file_path = samples_dir / "model_file.pth"
# model_version_id = "fake_model_version_id"
# training = TrainingWorker()
# with create_archive(path=model_file_path) as (
# zst_archive_path,
# hash,
# size,
# archive_hash,
# ):
# responses.add(
# responses.POST,
# f"http://testserver/api/v1/model/{model_id}/versions/",
# status=403,
# match=[
# matchers.json_params_matcher(
# {"hash": hash, "archive_hash": archive_hash, "size": size}
# )
# ],
# )
# training.api_client = MockApiClient()
# hash = "hash"
# archive_hash = "archive_hash"
# size = "30"
# model_version_details = {
# "id": model_version_id,
# "model_id": model_id,
# "hash": hash,
# "archive_hash": archive_hash,
# "size": size,
# "s3_url": "http://hehehe.com",
# "s3_put_url": "http://hehehe.com",
# }
# training.api_client.add_error_response(
# "CreateModelVersion",
# id=model_id,
# status_code=403,
# body={"hash": hash, "archive_hash": archive_hash, "size": size}
# )
# with pytest.raises(Exception):
# training.create_model_version(api_client, model_id, hash, size, archive_hash)
# training.create_model_version(model_id, hash, size, archive_hash)
# def test_handle_s3_uploading_errors(samples_dir):
# s3_endpoint_url = "http://s3.localhost.com"
# responses.add_passthru(s3_endpoint_url)
# responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
# file_path = samples_dir / "model_file.pth"
# training = TrainingWorker()
# with pytest.raises(Exception):
# training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
def test_handle_s3_uploading_errors(model_file_dir):
training = TrainingWorker()
training.api_client = MockApiClient()
s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url)
responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
file_path = model_file_dir / "model_file.pth"
with pytest.raises(Exception):
training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment