diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index bacd1dde8f451206ff68169e0364191cd17b4387..37dbc30796fb6c55e0c27fb35d6b5ba8918bc801 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -98,6 +98,7 @@ class TrainingMixin(object): model_id: str, tag: Optional[str] = None, description: Optional[str] = None, + configuration: Optional[dict] = {}, ): """ This method creates a model archive and its associated hash, @@ -105,8 +106,9 @@ class TrainingMixin(object): :param model_path: Path to the directory containing the model version's files. :param model_id: ID of the model - :param tag: - :param description: + :param tag: Tag of the model version + :param description: Description of the model version + :param configuration: Configuration of the model version """ if self.is_read_only: @@ -140,7 +142,7 @@ class TrainingMixin(object): # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker) self.update_model_version( - model_version_details=model_version_details, + model_version_details=model_version_details, configuration=configuration ) def create_model_version( @@ -222,7 +224,7 @@ class TrainingMixin(object): def update_model_version( self, model_version_details: dict, - configuration: dict = {}, + configuration: dict, ) -> None: """ Update the specified model version to the state `Available` and use the given information"