Skip to content
Snippets Groups Projects
Commit f67979ae authored by NolanB's avatar NolanB
Browse files

Commit for help fixing tests

parent 9c55c683
No related branches found
No related tags found
No related merge requests found
Pipeline #79363 failed
......@@ -6,7 +6,7 @@ import tempfile
from contextlib import contextmanager
from typing import NewType, Tuple
import requests
# import requests
import zstandard as zstd
from apistar.exceptions import ErrorResponse
......@@ -77,7 +77,7 @@ def create_archive(path: FilePath) -> Archive:
class TrainingMixin(object):
def publish_model_version(self, client, model_path, model_id):
def publish_model_version(self, model_path, model_id):
# Create the zst archive, get its hash and size
with create_archive(path=model_path) as (
path_to_archive,
......@@ -87,7 +87,7 @@ class TrainingMixin(object):
):
# Create a new model version with hash and size
model_version_details = self.create_model_version(
client=client,
client=self.api_client,
model_id=model_id,
hash=hash,
size=size,
......@@ -102,12 +102,17 @@ class TrainingMixin(object):
# Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
self.update_model_version(
client=client,
client=self.api_client,
model_version_details=model_version_details,
)
def create_model_version(
client: ArkindexClient, model_id: str, hash: str, size: int, archive_hash: str
self,
client: ArkindexClient,
model_id: str,
hash: str,
size: int,
archive_hash: str,
) -> dict:
# Create a new model version with hash and size
try:
......@@ -129,12 +134,12 @@ class TrainingMixin(object):
return
return model_version_details
def upload_to_s3(archive_path: str, model_version_details: dict) -> None:
def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
s3_put_url = model_version_details.get("s3_put_url")
logger.info("Uploading to s3...")
# Upload the archive on s3
with open(archive_path, "rb") as archive:
r = requests.put(
r = self.request.put(
url=s3_put_url,
data=archive,
headers={"Content-Type": "application/zstd"},
......@@ -142,17 +147,23 @@ class TrainingMixin(object):
r.raise_for_status()
def update_model_version(
client: ArkindexClient, model_version_details: dict
self,
model_version_details: dict,
description: str = None,
configuration: dict = None,
tag: str = None,
) -> None:
logger.info("Updating the model version...")
try:
client.request(
# request or requests ?
self.request(
"UpdateModelVersion",
id=model_version_details.get("id"),
body={
"state": "available",
"description": "DOC UFCN",
"configuration": {},
"description": description,
"configuration": configuration,
"tag": tag,
},
)
except ErrorResponse as e:
......
pytest==7.1.1
pytest-mock==3.7.0
pytest-responses==0.5.0
requests==2.27.1
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
from http import client
import imp
import pytest
from arkindex_worker.worker.training import create_archive, TrainingMixin
import os
import responses
from responses import matchers
from pathlib import Path
from arkindex.mock import MockApiClient
def test_create_archive_folder():
model_file_dir = Path("tests/samples/model_files")
with create_archive(path=model_file_dir) as (
zst_archive_path,
hash,
size,
archive_hash,
):
assert os.path.exists(zst_archive_path), "The archive was not created"
assert (
hash == "7dd70931222ef0496ea75e5aee674043"
), "Hash was not properly computed"
assert 300 < size < 700
assert not os.path.exists(zst_archive_path), "Auto removal failed"
def test_create_model_version():
api_client = MockApiClient()
"""A new model version is returned"""
model_id = "fake_model_id"
model_version_id = "fake_model_version_id"
# Create a model archive and keep its hash and size.
model_files_dir = Path("tests/samples/model_files")
# model_file_path = model_files_dir / "model_file.pth"
training = TrainingMixin()
with create_archive(path=model_files_dir) as (
zst_archive_path,
hash,
size,
archive_hash,
):
model_version_details = {
"id": model_version_id,
"model_id": model_id,
"hash": hash,
"archive_hash": archive_hash,
"size": size,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
responses.add(
responses.POST,
f"http://testserver/api/v1/model/{model_id}/versions/",
status=200,
match=[
matchers.json_params_matcher(
{"hash": hash, "archive_hash": archive_hash, "size": size}
)
],
json=model_version_details,
)
assert (
training.create_model_version(api_client, model_id, hash, size, archive_hash)
== model_version_details
)
# def test_retrieve_created_model_version(api_client, samples_dir):
# """There is an existing model version in Created mode, A 400 was raised.
# But the model is still returned in error content
# """
# model_id = "fake_model_id"
# model_version_id = "fake_model_version_id"
# # Create a model archive and keep its hash and size.
# model_file_path = samples_dir / "model_file.pth"
# training = TrainingMixin()
# with create_archive(path=model_file_path) as (
# zst_archive_path,
# hash,
# size,
# archive_hash,
# ):
# model_version_details = {
# "id": model_version_id,
# "model_id": model_id,
# "hash": hash,
# "archive_hash": archive_hash,
# "size": size,
# "s3_url": "http://hehehe.com",
# "s3_put_url": "http://hehehe.com",
# }
# responses.add(
# responses.POST,
# f"http://testserver/api/v1/model/{model_id}/versions/",
# status=400,
# match=[
# matchers.json_params_matcher(
# {"hash": hash, "archive_hash": archive_hash, "size": size}
# )
# ],
# json={"hash": model_version_details},
# )
# assert (
# training.create_model_version(api_client, model_id, hash, size, archive_hash)
# == model_version_details
# )
# def test_retrieve_available_model_version(api_client, samples_dir):
# """Raise error when there is an existing model version in Available mode"""
# model_id = "fake_model_id"
# # Create a model archive and keep its hash and size.
# model_file_path = samples_dir / "model_file.pth"
# training = TrainingMixin()
# with create_archive(path=model_file_path) as (
# zst_archive_path,
# hash,
# size,
# archive_hash,
# ):
# responses.add(
# responses.POST,
# f"http://testserver/api/v1/model/{model_id}/versions/",
# status=403,
# match=[
# matchers.json_params_matcher(
# {"hash": hash, "archive_hash": archive_hash, "size": size}
# )
# ],
# )
# with pytest.raises(Exception):
# training.create_model_version(api_client, model_id, hash, size, archive_hash)
# def test_handle_s3_uploading_errors(samples_dir):
# s3_endpoint_url = "http://s3.localhost.com"
# responses.add_passthru(s3_endpoint_url)
# responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
# file_path = samples_dir / "model_file.pth"
# training = TrainingMixin()
# with pytest.raises(Exception):
# training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment