Skip to content
Snippets Groups Projects
Commit 26e02ac0 authored by NolanB's avatar NolanB
Browse files

Fix the error with the code review

parent f67979ae
No related branches found
No related tags found
No related merge requests found
Pipeline #79366 failed
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml,zstandard
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
......@@ -6,11 +6,10 @@ import tempfile
from contextlib import contextmanager
from typing import NewType, Tuple
# import requests
import requests
import zstandard as zstd
from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient
from arkindex_worker import logger
CHUNK_SIZE = 1024
......@@ -87,7 +86,6 @@ class TrainingMixin(object):
):
# Create a new model version with hash and size
model_version_details = self.create_model_version(
client=self.api_client,
model_id=model_id,
hash=hash,
size=size,
......@@ -102,13 +100,11 @@ class TrainingMixin(object):
# Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
self.update_model_version(
client=self.api_client,
model_version_details=model_version_details,
)
def create_model_version(
self,
client: ArkindexClient,
model_id: str,
hash: str,
size: int,
......@@ -116,7 +112,7 @@ class TrainingMixin(object):
) -> dict:
# Create a new model version with hash and size
try:
model_version_details = client.request(
model_version_details = self.request(
"CreateModelVersion",
id=model_id,
body={"hash": hash, "size": size, "archive_hash": archive_hash},
......@@ -139,7 +135,7 @@ class TrainingMixin(object):
logger.info("Uploading to s3...")
# Upload the archive on s3
with open(archive_path, "rb") as archive:
r = self.request.put(
r = requests.put(
url=s3_put_url,
data=archive,
headers={"Content-Type": "application/zstd"},
......@@ -155,7 +151,6 @@ class TrainingMixin(object):
) -> None:
logger.info("Updating the model version...")
try:
# request or requests ?
self.request(
"UpdateModelVersion",
id=model_version_details.get("id"),
......
from http import client
import imp
import pytest
from arkindex_worker.worker.training import create_archive, TrainingMixin
# -*- coding: utf-8 -*-
import os
from pathlib import Path
import responses
from responses import matchers
from pathlib import Path
from arkindex.mock import MockApiClient
from arkindex_worker.worker.training import TrainingMixin, create_archive
def test_create_archive_folder():
model_file_dir = Path("tests/samples/model_files")
......@@ -141,7 +141,6 @@ def test_create_model_version():
# training.create_model_version(api_client, model_id, hash, size, archive_hash)
# def test_handle_s3_uploading_errors(samples_dir):
# s3_endpoint_url = "http://s3.localhost.com"
# responses.add_passthru(s3_endpoint_url)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment