From 26e02ac040ee487dedf2ff4f3b4f0f827f7d09d7 Mon Sep 17 00:00:00 2001 From: NolanB <nboukachab@teklia.com> Date: Wed, 17 Aug 2022 16:44:44 +0200 Subject: [PATCH] Fix the error with the code review --- .isort.cfg | 2 +- arkindex_worker/worker/training.py | 11 +++-------- tests/test_elements_worker/test_training.py | 11 +++++------ 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index f03c5435..ad4d2fb8 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -8,4 +8,4 @@ line_length = 88 default_section=FIRSTPARTY known_first_party = arkindex,arkindex_common -known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml,zstandard +known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 8ffe9048..88bebe38 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -6,11 +6,10 @@ import tempfile from contextlib import contextmanager from typing import NewType, Tuple -# import requests +import requests import zstandard as zstd from apistar.exceptions import ErrorResponse -from arkindex import ArkindexClient from arkindex_worker import logger CHUNK_SIZE = 1024 @@ -87,7 +86,6 @@ class TrainingMixin(object): ): # Create a new model version with hash and size model_version_details = self.create_model_version( - client=self.api_client, model_id=model_id, hash=hash, size=size, @@ -102,13 +100,11 @@ class TrainingMixin(object): # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker) self.update_model_version( - client=self.api_client, model_version_details=model_version_details, ) def create_model_version( self, - client: ArkindexClient, model_id: str, hash: str, size: int, @@ -116,7 +112,7 @@ class TrainingMixin(object): ) -> dict: # Create a new model version with hash and size try: - model_version_details = client.request( + model_version_details = self.request( "CreateModelVersion", id=model_id, body={"hash": hash, "size": size, "archive_hash": archive_hash}, @@ -139,7 +135,7 @@ class TrainingMixin(object): logger.info("Uploading to s3...") # Upload the archive on s3 with open(archive_path, "rb") as archive: - r = self.request.put( + r = requests.put( url=s3_put_url, data=archive, headers={"Content-Type": "application/zstd"}, @@ -155,7 +151,6 @@ class TrainingMixin(object): ) -> None: logger.info("Updating the model version...") try: - # request or requests ? self.request( "UpdateModelVersion", id=model_version_details.get("id"), diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index a20e3e4b..4661ac13 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -1,13 +1,13 @@ -from http import client -import imp -import pytest -from arkindex_worker.worker.training import create_archive, TrainingMixin +# -*- coding: utf-8 -*- import os +from pathlib import Path + import responses from responses import matchers -from pathlib import Path from arkindex.mock import MockApiClient +from arkindex_worker.worker.training import TrainingMixin, create_archive + def test_create_archive_folder(): model_file_dir = Path("tests/samples/model_files") @@ -141,7 +141,6 @@ def test_create_model_version(): # training.create_model_version(api_client, model_id, hash, size, archive_hash) - # def test_handle_s3_uploading_errors(samples_dir): # s3_endpoint_url = "http://s3.localhost.com" # responses.add_passthru(s3_endpoint_url) -- GitLab