Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
B
Base Worker
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Workers
Base Worker
Commits
4561df33
Commit
4561df33
authored
2 years ago
by
Valentin Rigal
Committed by
Yoann Schneider
2 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Support new model version API
parent
0a1c4722
Branches
new-training-worker-class
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!287
Support new model version API
Pipeline
#80039
passed
2 years ago
Stage: release
Stage: deploy
Changes
2
Pipelines
11
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
arkindex_worker/worker/training.py
+190
-100
190 additions, 100 deletions
arkindex_worker/worker/training.py
tests/test_elements_worker/test_training.py
+176
-142
176 additions, 142 deletions
tests/test_elements_worker/test_training.py
with
366 additions
and
242 deletions
arkindex_worker/worker/training.py
+
190
−
100
View file @
4561df33
...
...
@@ -3,13 +3,15 @@
BaseWorker methods for training.
"""
import
functools
import
hashlib
import
os
import
tarfile
import
tempfile
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
typing
import
NewType
,
Optional
,
Tuple
from
typing
import
NewType
,
Optional
,
Tuple
,
Union
from
uuid
import
UUID
import
requests
import
zstandard
as
zstd
...
...
@@ -91,7 +93,41 @@ def create_archive(path: DirPath) -> Tuple[Path, Hash, FileSize, Hash]:
os
.
remove
(
path_to_zst_archive
)
def
build_clean_payload
(
**
kwargs
):
"""
Remove null attributes from an API body payload
"""
return
{
key
:
value
for
key
,
value
in
kwargs
.
items
()
if
value
is
not
None
}
def
skip_if_read_only
(
func
):
"""
Return shortly in case the read_only attribute is evaluated to True
"""
@functools.wraps
(
func
)
def
wrapper
(
self
,
*
args
,
**
kwargs
):
if
getattr
(
self
,
"
read_only
"
,
False
):
logger
.
warning
(
"
Cannot perform this operation as the worker is in read-only mode
"
)
return
return
func
(
self
,
*
args
,
**
kwargs
)
return
wrapper
class
TrainingMixin
(
object
):
"""
A mixin helper to create a new model version easily.
You may use `publish_model_version` to publish a ready model version directly, or
separately create the model version then publish it (e.g to store training metrics).
Stores the currently handled model version as `self.model_version`.
"""
model_version
=
None
@skip_if_read_only
def
publish_model_version
(
self
,
model_path
:
DirPath
,
...
...
@@ -99,25 +135,44 @@ class TrainingMixin(object):
tag
:
Optional
[
str
]
=
None
,
description
:
Optional
[
str
]
=
None
,
configuration
:
Optional
[
dict
]
=
{},
parent
:
Optional
[
Union
[
str
,
UUID
]]
=
None
,
):
"""
This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on Arkindex.
Publish a unique version of a model in Arkindex, identified by its hash.
In case the `create_model_version` method has been called, reuses that model
instead of creating a new one.
:param model_path: Path to the directory containing the model version
'
s files.
:param model_id: ID of the model
:param tag: Tag of the model version
:param description: Description of the model version
:param configuration: Configuration of the model version
:param parent: ID of the parent model version
"""
if
self
.
is_read_only
:
logger
.
warning
(
"
Cannot publish a new model version as this worker is in read-only mode
"
if
not
self
.
model_version
:
self
.
create_model_version
(
model_id
=
model_id
,
tag
=
tag
,
description
=
description
,
configuration
=
configuration
,
parent
=
parent
,
)
elif
tag
or
description
or
configuration
or
parent
:
assert
(
self
.
model_version
.
get
(
"
model_id
"
)
==
model_id
),
"
Given `model_id` does not match the current model version
"
# If any attribute field has been defined, PATCH the current model version
self
.
update_model_version
(
tag
=
tag
,
description
=
description
,
configuration
=
configuration
,
parent
=
parent
,
)
return
# Create the zst archive, get its hash and size
# Validate the model version
with
create_archive
(
path
=
model_path
)
as
(
path_to_archive
,
hash
,
...
...
@@ -125,92 +180,110 @@ class TrainingMixin(object):
archive_hash
,
):
# Create a new model version with hash and size
model_version_details
=
self
.
create_model_version
(
model_id
=
model_id
,
hash
=
hash
,
self
.
upload_to_s3
(
archive_path
=
path_to_archive
)
current_version_id
=
self
.
model_version
[
"
id
"
]
# Mark the model as valid
self
.
validate_model_version
(
size
=
size
,
hash
=
hash
,
archive_hash
=
archive_hash
,
tag
=
tag
,
description
=
description
,
)
if
model_version_details
is
None
:
return
self
.
upload_to_s3
(
archive_path
=
path_to_archive
,
model_version_details
=
model_version_details
,
)
# Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
self
.
update_model_version
(
model_version_details
=
model_version_details
,
configuration
=
configuration
)
if
self
.
model_version
[
"
id
"
]
!=
current_version_id
and
(
tag
or
description
or
configuration
or
parent
):
logger
.
warning
(
"
Updating the existing available model version with the given attributes.
"
)
self
.
update_model_version
(
tag
=
tag
,
description
=
description
,
configuration
=
configuration
,
parent
=
parent
,
)
@skip_if_read_only
def
create_model_version
(
self
,
model_id
:
str
,
hash
:
str
,
size
:
int
,
archive_hash
:
str
,
tag
:
str
,
description
:
str
,
)
->
dict
:
tag
:
Optional
[
str
]
=
None
,
description
:
Optional
[
str
]
=
None
,
configuration
:
Optional
[
dict
]
=
{},
parent
:
Optional
[
Union
[
str
,
UUID
]]
=
None
,
):
"""
Create a new version of the specified model with the given information (hashes and size).
If a version matching the information already exist, there are two cases:
- The version is in `Created` state: this version
'
s details is used
- The version is in `Available` state: you cannot create twice the same version, an error is raised
Create a new version of the specified model with its base attributes.
Once successfully created, the model version is accessible via `self.model_version`.
:param tag: Tag of the model version
:param description: Description of the model version
:param configuration: Configuration of the model version
:param parent: ID of the parent model version
"""
if
self
.
is_read_only
:
logger
.
warning
(
"
Cannot create a new model version as this worker is in read-only mode
"
)
return
assert
not
self
.
model_version
,
"
A model version has already been created.
"
self
.
model_version
=
self
.
request
(
"
CreateModelVersion
"
,
id
=
model_id
,
body
=
build_clean_payload
(
tag
=
tag
,
description
=
description
,
configuration
=
configuration
,
parent
=
parent
,
),
)
logger
.
info
(
f
"
Model version (
{
self
.
model_version
[
'
id
'
]
}
) was successfully created
"
)
# Create a new model version with hash and size
try
:
payload
=
{
"
hash
"
:
hash
,
"
size
"
:
size
,
"
archive_hash
"
:
archive_hash
}
if
tag
:
payload
[
"
tag
"
]
=
tag
if
description
:
payload
[
"
description
"
]
=
description
model_version_details
=
self
.
request
(
"
CreateModelVersion
"
,
id
=
model_id
,
body
=
payload
,
)
logger
.
info
(
f
"
Model version (
{
model_version_details
[
'
id
'
]
}
) was created successfully
"
)
except
ErrorResponse
as
e
:
model_version_details
=
(
e
.
content
.
get
(
"
hash
"
)
if
hasattr
(
e
,
"
content
"
)
else
None
)
if
e
.
status_code
>=
500
or
model_version_details
is
None
:
logger
.
error
(
f
"
Failed to create model version:
{
e
.
content
}
"
)
raise
e
# If the existing model is in Created state, this model is returned as a dict.
# Else an error in a list is returned.
if
isinstance
(
model_version_details
,
(
list
,
tuple
)):
logger
.
error
(
model_version_details
[
0
])
return
logger
.
info
(
f
"
Model version (
{
model_version_details
[
'
id
'
]
}
) has the same hash, using this one instead of creating one
"
)
@skip_if_read_only
def
update_model_version
(
self
,
tag
:
Optional
[
str
]
=
None
,
description
:
Optional
[
str
]
=
None
,
configuration
:
Optional
[
dict
]
=
None
,
parent
:
Optional
[
Union
[
str
,
UUID
]]
=
None
,
):
"""
Update the current model version with the given attributes.
return
model_version_details
:param tag: Tag of the model version
:param description: Description of the model version
:param configuration: Configuration of the model version
:param parent: ID of the parent model version
"""
assert
self
.
model_version
,
"
No model version has been created yet.
"
self
.
model_version
=
self
.
request
(
"
UpdateModelVersion
"
,
id
=
self
.
model_version
[
"
id
"
],
body
=
build_clean_payload
(
tag
=
tag
,
description
=
description
,
configuration
=
configuration
,
parent
=
parent
,
),
)
logger
.
info
(
f
"
Model version (
{
self
.
model_version
[
'
id
'
]
}
) was successfully updated
"
)
def
upload_to_s3
(
self
,
archive_path
:
str
,
model_version_details
:
dict
)
->
None
:
@skip_if_read_only
def
upload_to_s3
(
self
,
archive_path
:
str
)
->
None
:
"""
Upload the archive of the model
'
s files to an Amazon s3 compatible storage
"""
if
self
.
is_read_only
:
logger
.
warning
(
"
Cannot upload this archive as this worker is in read-only mode
"
)
return
s3_put_url
=
model_version_details
.
get
(
"
s3_put_url
"
)
assert
(
self
.
model_version
),
"
You must create the model version before uploading an archive.
"
assert
(
self
.
model_version
[
"
state
"
]
!=
"
Available
"
),
"
The model is already marked as available.
"
s3_put_url
=
self
.
model_version
.
get
(
"
s3_put_url
"
)
assert
(
s3_put_url
),
"
S3 PUT URL is not set, please ensure you have the right to validate a model version.
"
logger
.
info
(
"
Uploading to s3...
"
)
# Upload the archive on s3
with
open
(
archive_path
,
"
rb
"
)
as
archive
:
...
...
@@ -221,33 +294,50 @@ class TrainingMixin(object):
)
r
.
raise_for_status
()
def
update_model_version
(
@skip_if_read_only
def
validate_model_version
(
self
,
model_version_details
:
dict
,
configuration
:
dict
,
)
->
None
:
"""
Update the specified model version to the state `Available` and use the given information
"
hash
:
str
,
size
:
int
,
archive_hash
:
str
,
):
"""
if
self
.
is_read_only
:
logger
.
warning
(
"
Cannot update this model version as this worker is in read-only mode
"
)
return
Sets the model version as `Available`, once its archive has been uploaded to S3.
model_version_id
=
model_version_details
.
get
(
"
id
"
)
logger
.
info
(
f
"
Updating model version (
{
model_version_id
}
)
"
)
:param hash: MD5 hash of the files contained in the archive
:param size: The size of the uploaded archive
:param archive_hash: MD5 hash of the uploaded archive
"""
assert
(
self
.
model_version
),
"
You must create the model version and upload its archive before validating it.
"
try
:
self
.
request
(
"
Up
dateModelVersion
"
,
id
=
model_version
_id
,
self
.
model_version
=
self
.
request
(
"
Vali
dateModelVersion
"
,
id
=
self
.
model_version
[
"
id
"
]
,
body
=
{
"
state
"
:
"
available
"
,
"
description
"
:
model_version_details
.
get
(
"
description
"
),
"
configuration
"
:
configuration
,
"
tag
"
:
model_version_details
.
get
(
"
tag
"
),
"
size
"
:
size
,
"
hash
"
:
hash
,
"
archive_hash
"
:
archive_hash
,
},
)
logger
.
info
(
f
"
Model version (
{
model_version_id
}
) was successfully updated
"
)
except
ErrorResponse
as
e
:
logger
.
error
(
f
"
Failed to update model version:
{
e
.
content
}
"
)
if
e
.
status_code
!=
409
:
raise
e
logger
.
warning
(
f
"
An available model version exists with hash
{
hash
}
, using it instead of the pending version.
"
)
pending_version_id
=
self
.
model_version
[
"
id
"
]
self
.
model_version
=
getattr
(
e
,
"
content
"
,
None
)
assert
self
.
model_version
is
not
None
,
"
An unexpected error occurred.
"
logger
.
warning
(
"
Removing the pending model version.
"
)
try
:
self
.
request
(
"
DestroyModelVersion
"
,
id
=
pending_version_id
)
except
ErrorResponse
as
e
:
msg
=
getattr
(
e
,
"
content
"
,
str
(
e
))
logger
.
error
(
f
"
An error occurred removing the pending version
{
pending_version_id
}
:
{
msg
}
.
"
)
logger
.
info
(
f
"
Model version
{
self
.
model_version
[
'
id
'
]
}
is now available.
"
)
This diff is collapsed.
Click to expand it.
tests/test_elements_worker/test_training.py
+
176
−
142
View file @
4561df33
# -*- coding: utf-8 -*-
import
logging
import
os
import
sys
...
...
@@ -10,16 +11,13 @@ from arkindex_worker.worker import BaseWorker
from
arkindex_worker.worker.training
import
TrainingMixin
,
create_archive
class
TrainingWorker
(
BaseWorker
,
TrainingMixin
):
"""
This class is only needed for tests
"""
pass
@pytest.fixture
def
mock_training_worker
(
monkeypatch
):
class
TrainingWorker
(
BaseWorker
,
TrainingMixin
):
"""
This class is needed to run tests in the context of a training worker
"""
monkeypatch
.
setattr
(
sys
,
"
argv
"
,
[
"
worker
"
])
training_worker
=
TrainingWorker
()
training_worker
.
api_client
=
MockApiClient
()
...
...
@@ -27,8 +25,27 @@ def mock_training_worker(monkeypatch):
return
training_worker
@pytest.fixture
def
default_model_version
():
return
{
"
id
"
:
"
model_version_id
"
,
"
model_id
"
:
"
model_id
"
,
"
state
"
:
"
created
"
,
"
parent
"
:
"
42
"
*
16
,
"
tag
"
:
"
A simple tag
"
,
"
description
"
:
"
A description
"
,
"
configuration
"
:
{
"
test
"
:
"
value
"
},
"
s3_url
"
:
None
,
"
s3_put_url
"
:
"
http://upload.archive
"
,
"
hash
"
:
None
,
"
archive_hash
"
:
None
,
"
size
"
:
None
,
"
created
"
:
"
2000-01-01T00:00:00Z
"
,
}
def
test_create_archive
(
model_file_dir
):
"""
Create an archive w
hen the model
'
s file is in a folder
"""
"""
Create an archive w
ith all base attributes
"""
with
create_archive
(
path
=
model_file_dir
)
as
(
zst_archive_path
,
...
...
@@ -63,155 +80,172 @@ def test_create_archive_with_subfolder(model_file_dir_with_subfolder):
assert
not
os
.
path
.
exists
(
zst_archive_path
),
"
Auto removal failed
"
def
test_handle_s3_uploading_errors
(
mock_training_worker
,
model_file_dir
):
s3_endpoint_url
=
"
http://s3.localhost.com
"
responses
.
add_passthru
(
s3_endpoint_url
)
responses
.
add
(
responses
.
Response
(
method
=
"
PUT
"
,
url
=
s3_endpoint_url
,
status
=
400
))
file_path
=
model_file_dir
/
"
model_file.pth
"
with
pytest
.
raises
(
Exception
):
mock_training_worker
.
upload_to_s3
(
file_path
,
{
"
s3_put_url
"
:
s3_endpoint_url
})
@pytest.mark.parametrize
(
"
tag, description
"
,
"
method
"
,
[
(
"
tag
"
,
"
description
"
),
(
None
,
"
description
"
),
(
""
,
"
description
"
),
(
"
tag
"
,
""
),
(
""
,
""
),
(
None
,
None
),
(
"
publish_model_version
"
),
(
"
create_model_version
"
),
(
"
update_model_version
"
),
(
"
upload_to_s3
"
),
(
"
validate_model_version
"
),
],
)
def
test_create_model_version
(
mock_training_worker
,
tag
,
description
):
"""
A new model version is returned
"""
model_version_id
=
"
fake_model_version_id
"
model_id
=
"
fake_model_id
"
model_hash
=
"
hash
"
archive_hash
=
"
archive_hash
"
size
=
"
30
"
model_version_details
=
{
"
id
"
:
model_version_id
,
"
model_id
"
:
model_id
,
"
hash
"
:
model_hash
,
"
archive_hash
"
:
archive_hash
,
"
size
"
:
size
,
"
tag
"
:
tag
,
"
description
"
:
description
,
"
s3_url
"
:
"
http://hehehe.com
"
,
"
s3_put_url
"
:
"
http://hehehe.com
"
,
}
def
test_training_mixin_read_only
(
mock_training_worker
,
method
,
caplog
):
"""
All operations related to models versions returns early if the worker is configured as read only
"""
mock_training_worker
.
read_only
=
True
assert
mock_training_worker
.
model_version
is
None
getattr
(
mock_training_worker
,
method
)()
assert
mock_training_worker
.
model_version
is
None
assert
[(
level
,
message
)
for
_
,
level
,
message
in
caplog
.
record_tuples
]
==
[
(
logging
.
WARNING
,
"
Cannot perform this operation as the worker is in read-only mode
"
,
),
]
def
test_create_model_version_already_created
(
mock_training_worker
):
mock_training_worker
.
model_version
=
{
"
id
"
:
"
model_version_id
"
}
with
pytest
.
raises
(
AssertionError
,
match
=
"
A model version has already been created.
"
):
mock_training_worker
.
create_model_version
(
model_id
=
"
model_id
"
)
expected_payload
=
{
"
hash
"
:
model_hash
,
"
archive_hash
"
:
archive_hash
,
"
size
"
:
size
,
}
if
description
:
expected_payload
[
"
description
"
]
=
description
if
tag
:
expected_payload
[
"
tag
"
]
=
tag
@pytest.mark.parametrize
(
"
set_tag
"
,
[
True
,
False
])
def
test_create_model_version
(
mock_training_worker
,
default_model_version
,
set_tag
):
args
=
{
"
parent
"
:
"
42
"
*
16
,
"
tag
"
:
"
A simple tag
"
,
"
description
"
:
"
A description
"
,
"
configuration
"
:
{
"
test
"
:
"
value
"
},
}
if
not
set_tag
:
del
args
[
"
tag
"
]
default_model_version
[
"
tag
"
]
=
None
mock_training_worker
.
api_client
.
add_response
(
"
CreateModelVersion
"
,
id
=
model_id
,
response
=
model_version_details
,
body
=
expected_payload
,
)
assert
(
mock_training_worker
.
create_model_version
(
model_id
,
model_hash
,
size
,
archive_hash
,
tag
,
description
)
==
model_version_details
id
=
"
model_id
"
,
response
=
default_model_version
,
body
=
args
,
)
assert
mock_training_worker
.
model_version
is
None
mock_training_worker
.
create_model_version
(
model_id
=
"
model_id
"
,
**
args
)
assert
mock_training_worker
.
model_version
==
default_model_version
@pytest.mark.parametrize
(
"
content, status_code
"
,
[
(
{
"
id
"
:
"
fake_model_version_id
"
,
"
model_id
"
:
"
fake_model_id
"
,
"
hash
"
:
"
hash
"
,
"
archive_hash
"
:
"
archive_hash
"
,
"
size
"
:
"
size
"
,
"
tag
"
:
"
tag
"
,
"
description
"
:
"
description
"
,
"
s3_url
"
:
"
http://hehehe.com
"
,
"
s3_put_url
"
:
"
http://hehehe.com
"
,
},
400
,
),
([
"
A version for this model with this hash already exists.
"
],
403
),
],
)
def
test_retrieve_created_model_version
(
mock_training_worker
,
content
,
status_code
):
"""
If there is an existing model version in Created mode,
A 400 was raised, but the model is still returned in error content.
Else if an existing model version in Available mode,
403 was raised, but None will be returned
"""
model_id
=
"
fake_model_id
"
model_hash
=
"
hash
"
archive_hash
=
"
archive_hash
"
size
=
"
30
"
mock_training_worker
.
api_client
.
add_error_response
(
"
CreateModelVersion
"
,
id
=
model_id
,
status_code
=
status_code
,
body
=
{
"
hash
"
:
model_hash
,
"
archive_hash
"
:
archive_hash
,
"
size
"
:
size
,
},
content
=
{
"
hash
"
:
content
},
def
test_update_model_version_not_created
(
mock_training_worker
):
with
pytest
.
raises
(
AssertionError
,
match
=
"
No model version has been created yet.
"
):
mock_training_worker
.
update_model_version
()
def
test_update_model_version
(
mock_training_worker
,
default_model_version
):
mock_training_worker
.
model_version
=
default_model_version
args
=
{
"
tag
"
:
"
A new tag
"
}
new_model_version
=
{
**
default_model_version
,
"
tag
"
:
"
A new tag
"
}
mock_training_worker
.
api_client
.
add_response
(
"
UpdateModelVersion
"
,
id
=
"
model_version_id
"
,
response
=
new_model_version
,
body
=
args
,
)
if
status_code
==
400
:
assert
(
mock_training_worker
.
create_model_version
(
model_id
,
model_hash
,
size
,
archive_hash
,
tag
=
None
,
description
=
None
)
==
content
)
elif
status_code
==
403
:
assert
(
mock_training_worker
.
create_model_version
(
model_id
,
model_hash
,
size
,
archive_hash
,
tag
=
None
,
description
=
None
)
is
None
)
mock_training_worker
.
update_model_version
(
**
args
)
assert
mock_training_worker
.
model_version
==
new_model_version
@pytest.mark.parametrize
(
"
content, status_code
"
,
(
# error 500
({
"
id
"
:
"
fake_id
"
},
500
),
# model_version details is None
({},
403
),
(
None
,
403
),
),
)
def
test_handle_500_create_model_version
(
mock_training_worker
,
content
,
status_code
):
model_id
=
"
fake_model_id
"
model_hash
=
"
hash
"
archive_hash
=
"
archive_hash
"
size
=
"
30
"
def
test_validate_model_version_not_created
(
mock_training_worker
):
with
pytest
.
raises
(
AssertionError
,
match
=
"
You must create the model version and upload its archive before validating it.
"
,
):
mock_training_worker
.
validate_model_version
(
hash
=
"
a
"
,
size
=
1
,
archive_hash
=
"
b
"
)
@pytest.mark.parametrize
(
"
deletion_failed
"
,
[
True
,
False
])
def
test_validate_model_version_hash_conflict
(
mock_training_worker
,
default_model_version
,
caplog
,
deletion_failed
):
mock_training_worker
.
model_version
=
{
"
id
"
:
"
another_id
"
}
args
=
{
"
hash
"
:
"
hash
"
,
"
archive_hash
"
:
"
archive_hash
"
,
"
size
"
:
30
,
}
mock_training_worker
.
api_client
.
add_error_response
(
"
CreateModelVersion
"
,
id
=
model_id
,
status_code
=
status_code
,
body
=
{
"
hash
"
:
model_hash
,
"
archive_hash
"
:
archive_hash
,
"
size
"
:
size
,
},
content
=
content
,
"
ValidateModelVersion
"
,
id
=
"
another_id
"
,
status_code
=
409
,
body
=
args
,
content
=
default_model_version
,
)
with
pytest
.
raises
(
Exception
):
mock_training_worker
.
create_model_version
(
model_id
,
model_hash
,
size
,
archive_hash
,
tag
=
None
,
description
=
None
if
deletion_failed
:
mock_training_worker
.
api_client
.
add_error_response
(
"
DestroyModelVersion
"
,
id
=
"
another_id
"
,
status_code
=
403
,
content
=
"
Not admin
"
,
)
else
:
mock_training_worker
.
api_client
.
add_response
(
"
DestroyModelVersion
"
,
id
=
"
another_id
"
,
response
=
"
No content
"
,
)
mock_training_worker
.
validate_model_version
(
**
args
)
assert
mock_training_worker
.
model_version
==
default_model_version
error_msg
=
[]
if
deletion_failed
:
error_msg
=
[
(
logging
.
ERROR
,
"
An error occurred removing the pending version another_id: Not admin.
"
,
)
]
assert
[
(
level
,
message
)
for
module
,
level
,
message
in
caplog
.
record_tuples
if
module
==
"
arkindex_worker
"
]
==
[
(
logging
.
WARNING
,
"
An available model version exists with hash hash, using it instead of the pending version.
"
,
),
(
logging
.
WARNING
,
"
Removing the pending model version.
"
),
*
error_msg
,
(
logging
.
INFO
,
"
Model version model_version_id is now available.
"
),
]
def
test_validate_model_version
(
mock_training_worker
,
default_model_version
,
caplog
):
mock_training_worker
.
model_version
=
{
"
id
"
:
"
model_version_id
"
}
args
=
{
"
hash
"
:
"
hash
"
,
"
archive_hash
"
:
"
archive_hash
"
,
"
size
"
:
30
,
}
mock_training_worker
.
api_client
.
add_response
(
"
ValidateModelVersion
"
,
id
=
"
model_version_id
"
,
body
=
args
,
response
=
default_model_version
,
)
def
test_handle_s3_uploading_errors
(
mock_training_worker
,
model_file_dir
):
s3_endpoint_url
=
"
http://s3.localhost.com
"
responses
.
add_passthru
(
s3_endpoint_url
)
responses
.
add
(
responses
.
Response
(
method
=
"
PUT
"
,
url
=
s3_endpoint_url
,
status
=
400
))
file_path
=
model_file_dir
/
"
model_file.pth
"
with
pytest
.
raises
(
Exception
):
mock_training_worker
.
upload_to_s3
(
file_path
,
{
"
s3_put_url
"
:
s3_endpoint_url
})
mock_training_worker
.
validate_model_version
(
**
args
)
assert
mock_training_worker
.
model_version
==
default_model_version
assert
[
(
level
,
message
)
for
module
,
level
,
message
in
caplog
.
record_tuples
if
module
==
"
arkindex_worker
"
]
==
[
(
logging
.
INFO
,
"
Model version model_version_id is now available.
"
),
]
This diff is collapsed.
Click to expand it.
Yoann Schneider
@yschneider
mentioned in issue
pylaia#30 (closed)
·
2 years ago
mentioned in issue
pylaia#30 (closed)
mentioned in issue pylaia#30
Toggle commit list
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment