Skip to content
Snippets Groups Projects
Commit 0b2ee596 authored by NolanB's avatar NolanB
Browse files

Commit for help fixing tests after add TrainingWorker class

parent 26e02ac0
No related branches found
No related tags found
No related merge requests found
Pipeline #79368 failed
......@@ -128,6 +128,7 @@ class TrainingMixin(object):
):
logger.error(model_version_details[0])
return
return model_version_details
def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
......
......@@ -5,10 +5,15 @@ from pathlib import Path
import responses
from responses import matchers
from arkindex.mock import MockApiClient
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.training import TrainingMixin, create_archive
class TrainingWorker(BaseWorker, TrainingMixin):
def __init__(self):
super().setup_api_client()
def test_create_archive_folder():
model_file_dir = Path("tests/samples/model_files")
......@@ -28,15 +33,13 @@ def test_create_archive_folder():
def test_create_model_version():
api_client = MockApiClient()
"""A new model version is returned"""
model_id = "fake_model_id"
model_version_id = "fake_model_version_id"
# Create a model archive and keep its hash and size.
model_files_dir = Path("tests/samples/model_files")
# model_file_path = model_files_dir / "model_file.pth"
training = TrainingMixin()
training = TrainingWorker()
with create_archive(path=model_files_dir) as (
zst_archive_path,
hash,
......@@ -66,7 +69,7 @@ def test_create_model_version():
)
assert (
training.create_model_version(api_client, model_id, hash, size, archive_hash)
training.create_model_version(model_id, hash, size, archive_hash)
== model_version_details
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment