Skip to content
Snippets Groups Projects
Commit 4bd8ff64 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Disable training methods usage in read only mode

parent 76796f15
No related branches found
No related tags found
1 merge request!207Disable training methods usage in read only mode
Pipeline #79543 passed
......@@ -87,6 +87,11 @@ class TrainingMixin(object):
This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on arkindex.
"""
if self.is_read_only:
logger.warning(
"Cannot publish a new model version as this worker is in read-only mode"
)
return
# Create the zst archive, get its hash and size
with create_archive(path=model_path) as (
......@@ -131,6 +136,11 @@ class TrainingMixin(object):
- The version is in `Created` state: this version's details is used
- The version is in `Available` state: you cannot create twice the same version, an error is raised
"""
if self.is_read_only:
logger.warning(
"Cannot create a new model version as this worker is in read-only mode"
)
return
# Create a new model version with hash and size
try:
......@@ -171,6 +181,11 @@ class TrainingMixin(object):
"""
Upload the archive of the model's files to an Amazon s3 compatible storage
"""
if self.is_read_only:
logger.warning(
"Cannot upload this archive as this worker is in read-only mode"
)
return
s3_put_url = model_version_details.get("s3_put_url")
logger.info("Uploading to s3...")
......@@ -191,6 +206,11 @@ class TrainingMixin(object):
"""
Update the specified model version to the state `Available` and use the given information"
"""
if self.is_read_only:
logger.warning(
"Cannot update this model version as this worker is in read-only mode"
)
return
model_version_id = model_version_details.get("id")
logger.info(f"Updating model version ({model_version_id})")
......
# -*- coding: utf-8 -*-
import os
import sys
import pytest
import responses
......@@ -17,6 +18,15 @@ class TrainingWorker(BaseWorker, TrainingMixin):
pass
@pytest.fixture
def mock_training_worker(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker"])
training_worker = TrainingWorker()
training_worker.api_client = MockApiClient()
training_worker.args = training_worker.parser.parse_args()
return training_worker
def test_create_archive(model_file_dir):
"""Create an archive when the model's file is in a folder"""
......@@ -46,13 +56,11 @@ def test_create_archive(model_file_dir):
(None, ""),
],
)
def test_create_model_version(tag, description):
def test_create_model_version(mock_training_worker, tag, description):
"""A new model version is returned"""
model_version_id = "fake_model_version_id"
model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
......@@ -68,7 +76,7 @@ def test_create_model_version(tag, description):
"s3_put_url": "http://hehehe.com",
}
training.api_client.add_response(
mock_training_worker.api_client.add_response(
"CreateModelVersion",
id=model_id,
response=model_version_details,
......@@ -81,7 +89,7 @@ def test_create_model_version(tag, description):
},
)
assert (
training.create_model_version(
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== model_version_details
......@@ -126,7 +134,7 @@ def test_create_model_version(tag, description):
({"hash": ["A version for this model with this hash already exists."]}, 403),
],
)
def test_retrieve_created_model_version(content, status_code):
def test_retrieve_created_model_version(mock_training_worker, content, status_code):
"""
If there is an existing model version in Created mode,
A 400 was raised, but the model is still returned in error content.
......@@ -135,14 +143,12 @@ def test_retrieve_created_model_version(content, status_code):
"""
model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
tag = "tag"
description = "description"
training.api_client.add_error_response(
mock_training_worker.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
......@@ -157,26 +163,24 @@ def test_retrieve_created_model_version(content, status_code):
)
if status_code == 400:
assert (
training.create_model_version(
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== content["hash"]
)
elif status_code == 403:
assert (
training.create_model_version(
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
is None
)
def test_handle_s3_uploading_errors(model_file_dir):
training = TrainingWorker()
training.api_client = MockApiClient()
def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir):
s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url)
responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
file_path = model_file_dir / "model_file.pth"
with pytest.raises(Exception):
training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
mock_training_worker.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment