From 26e02ac040ee487dedf2ff4f3b4f0f827f7d09d7 Mon Sep 17 00:00:00 2001
From: NolanB <nboukachab@teklia.com>
Date: Wed, 17 Aug 2022 16:44:44 +0200
Subject: [PATCH] Fix the error with the code review

---
 .isort.cfg                                  |  2 +-
 arkindex_worker/worker/training.py          | 11 +++--------
 tests/test_elements_worker/test_training.py | 11 +++++------
 3 files changed, 9 insertions(+), 15 deletions(-)

diff --git a/.isort.cfg b/.isort.cfg
index f03c5435..ad4d2fb8 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -8,4 +8,4 @@ line_length = 88
 
 default_section=FIRSTPARTY
 known_first_party = arkindex,arkindex_common
-known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml,zstandard
+known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py
index 8ffe9048..88bebe38 100644
--- a/arkindex_worker/worker/training.py
+++ b/arkindex_worker/worker/training.py
@@ -6,11 +6,10 @@ import tempfile
 from contextlib import contextmanager
 from typing import NewType, Tuple
 
-# import requests
+import requests
 import zstandard as zstd
 from apistar.exceptions import ErrorResponse
 
-from arkindex import ArkindexClient
 from arkindex_worker import logger
 
 CHUNK_SIZE = 1024
@@ -87,7 +86,6 @@ class TrainingMixin(object):
         ):
             # Create a new model version with hash and size
             model_version_details = self.create_model_version(
-                client=self.api_client,
                 model_id=model_id,
                 hash=hash,
                 size=size,
@@ -102,13 +100,11 @@ class TrainingMixin(object):
 
         # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
         self.update_model_version(
-            client=self.api_client,
             model_version_details=model_version_details,
         )
 
     def create_model_version(
         self,
-        client: ArkindexClient,
         model_id: str,
         hash: str,
         size: int,
@@ -116,7 +112,7 @@ class TrainingMixin(object):
     ) -> dict:
         # Create a new model version with hash and size
         try:
-            model_version_details = client.request(
+            model_version_details = self.request(
                 "CreateModelVersion",
                 id=model_id,
                 body={"hash": hash, "size": size, "archive_hash": archive_hash},
@@ -139,7 +135,7 @@ class TrainingMixin(object):
         logger.info("Uploading to s3...")
         # Upload the archive on s3
         with open(archive_path, "rb") as archive:
-            r = self.request.put(
+            r = requests.put(
                 url=s3_put_url,
                 data=archive,
                 headers={"Content-Type": "application/zstd"},
@@ -155,7 +151,6 @@ class TrainingMixin(object):
     ) -> None:
         logger.info("Updating the model version...")
         try:
-            # request or requests ?
             self.request(
                 "UpdateModelVersion",
                 id=model_version_details.get("id"),
diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py
index a20e3e4b..4661ac13 100644
--- a/tests/test_elements_worker/test_training.py
+++ b/tests/test_elements_worker/test_training.py
@@ -1,13 +1,13 @@
-from http import client
-import imp
-import pytest
-from arkindex_worker.worker.training import create_archive, TrainingMixin
+# -*- coding: utf-8 -*-
 import os
+from pathlib import Path
+
 import responses
 from responses import matchers
-from pathlib import Path
 
 from arkindex.mock import MockApiClient
+from arkindex_worker.worker.training import TrainingMixin, create_archive
+
 
 def test_create_archive_folder():
     model_file_dir = Path("tests/samples/model_files")
@@ -141,7 +141,6 @@ def test_create_model_version():
 #         training.create_model_version(api_client, model_id, hash, size, archive_hash)
 
 
-
 # def test_handle_s3_uploading_errors(samples_dir):
 #     s3_endpoint_url = "http://s3.localhost.com"
 #     responses.add_passthru(s3_endpoint_url)
-- 
GitLab