Skip to content
Snippets Groups Projects
Commit 4bd8ff64 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Disable training methods usage in read only mode

parent 76796f15
No related branches found
No related tags found
1 merge request!207Disable training methods usage in read only mode
Pipeline #79543 passed
...@@ -87,6 +87,11 @@ class TrainingMixin(object): ...@@ -87,6 +87,11 @@ class TrainingMixin(object):
This method creates a model archive and its associated hash, This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on arkindex. to create a unique version that will be stored on a bucket and published on arkindex.
""" """
if self.is_read_only:
logger.warning(
"Cannot publish a new model version as this worker is in read-only mode"
)
return
# Create the zst archive, get its hash and size # Create the zst archive, get its hash and size
with create_archive(path=model_path) as ( with create_archive(path=model_path) as (
...@@ -131,6 +136,11 @@ class TrainingMixin(object): ...@@ -131,6 +136,11 @@ class TrainingMixin(object):
- The version is in `Created` state: this version's details is used - The version is in `Created` state: this version's details is used
- The version is in `Available` state: you cannot create twice the same version, an error is raised - The version is in `Available` state: you cannot create twice the same version, an error is raised
""" """
if self.is_read_only:
logger.warning(
"Cannot create a new model version as this worker is in read-only mode"
)
return
# Create a new model version with hash and size # Create a new model version with hash and size
try: try:
...@@ -171,6 +181,11 @@ class TrainingMixin(object): ...@@ -171,6 +181,11 @@ class TrainingMixin(object):
""" """
Upload the archive of the model's files to an Amazon s3 compatible storage Upload the archive of the model's files to an Amazon s3 compatible storage
""" """
if self.is_read_only:
logger.warning(
"Cannot upload this archive as this worker is in read-only mode"
)
return
s3_put_url = model_version_details.get("s3_put_url") s3_put_url = model_version_details.get("s3_put_url")
logger.info("Uploading to s3...") logger.info("Uploading to s3...")
...@@ -191,6 +206,11 @@ class TrainingMixin(object): ...@@ -191,6 +206,11 @@ class TrainingMixin(object):
""" """
Update the specified model version to the state `Available` and use the given information" Update the specified model version to the state `Available` and use the given information"
""" """
if self.is_read_only:
logger.warning(
"Cannot update this model version as this worker is in read-only mode"
)
return
model_version_id = model_version_details.get("id") model_version_id = model_version_details.get("id")
logger.info(f"Updating model version ({model_version_id})") logger.info(f"Updating model version ({model_version_id})")
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import sys
import pytest import pytest
import responses import responses
...@@ -17,6 +18,15 @@ class TrainingWorker(BaseWorker, TrainingMixin): ...@@ -17,6 +18,15 @@ class TrainingWorker(BaseWorker, TrainingMixin):
pass pass
@pytest.fixture
def mock_training_worker(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker"])
training_worker = TrainingWorker()
training_worker.api_client = MockApiClient()
training_worker.args = training_worker.parser.parse_args()
return training_worker
def test_create_archive(model_file_dir): def test_create_archive(model_file_dir):
"""Create an archive when the model's file is in a folder""" """Create an archive when the model's file is in a folder"""
...@@ -46,13 +56,11 @@ def test_create_archive(model_file_dir): ...@@ -46,13 +56,11 @@ def test_create_archive(model_file_dir):
(None, ""), (None, ""),
], ],
) )
def test_create_model_version(tag, description): def test_create_model_version(mock_training_worker, tag, description):
"""A new model version is returned""" """A new model version is returned"""
model_version_id = "fake_model_version_id" model_version_id = "fake_model_version_id"
model_id = "fake_model_id" model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash" model_hash = "hash"
archive_hash = "archive_hash" archive_hash = "archive_hash"
size = "30" size = "30"
...@@ -68,7 +76,7 @@ def test_create_model_version(tag, description): ...@@ -68,7 +76,7 @@ def test_create_model_version(tag, description):
"s3_put_url": "http://hehehe.com", "s3_put_url": "http://hehehe.com",
} }
training.api_client.add_response( mock_training_worker.api_client.add_response(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
response=model_version_details, response=model_version_details,
...@@ -81,7 +89,7 @@ def test_create_model_version(tag, description): ...@@ -81,7 +89,7 @@ def test_create_model_version(tag, description):
}, },
) )
assert ( assert (
training.create_model_version( mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description model_id, model_hash, size, archive_hash, tag, description
) )
== model_version_details == model_version_details
...@@ -126,7 +134,7 @@ def test_create_model_version(tag, description): ...@@ -126,7 +134,7 @@ def test_create_model_version(tag, description):
({"hash": ["A version for this model with this hash already exists."]}, 403), ({"hash": ["A version for this model with this hash already exists."]}, 403),
], ],
) )
def test_retrieve_created_model_version(content, status_code): def test_retrieve_created_model_version(mock_training_worker, content, status_code):
""" """
If there is an existing model version in Created mode, If there is an existing model version in Created mode,
A 400 was raised, but the model is still returned in error content. A 400 was raised, but the model is still returned in error content.
...@@ -135,14 +143,12 @@ def test_retrieve_created_model_version(content, status_code): ...@@ -135,14 +143,12 @@ def test_retrieve_created_model_version(content, status_code):
""" """
model_id = "fake_model_id" model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash" model_hash = "hash"
archive_hash = "archive_hash" archive_hash = "archive_hash"
size = "30" size = "30"
tag = "tag" tag = "tag"
description = "description" description = "description"
training.api_client.add_error_response( mock_training_worker.api_client.add_error_response(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
status_code=status_code, status_code=status_code,
...@@ -157,26 +163,24 @@ def test_retrieve_created_model_version(content, status_code): ...@@ -157,26 +163,24 @@ def test_retrieve_created_model_version(content, status_code):
) )
if status_code == 400: if status_code == 400:
assert ( assert (
training.create_model_version( mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description model_id, model_hash, size, archive_hash, tag, description
) )
== content["hash"] == content["hash"]
) )
elif status_code == 403: elif status_code == 403:
assert ( assert (
training.create_model_version( mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description model_id, model_hash, size, archive_hash, tag, description
) )
is None is None
) )
def test_handle_s3_uploading_errors(model_file_dir): def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir):
training = TrainingWorker()
training.api_client = MockApiClient()
s3_endpoint_url = "http://s3.localhost.com" s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url) responses.add_passthru(s3_endpoint_url)
responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
file_path = model_file_dir / "model_file.pth" file_path = model_file_dir / "model_file.pth"
with pytest.raises(Exception): with pytest.raises(Exception):
training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url}) mock_training_worker.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment