From c51907e807e507d2525a806837716d7e426d0167 Mon Sep 17 00:00:00 2001
From: Anton Repushko
Date: Fri, 15 Mar 2024 16:49:29 +0100
Subject: [PATCH] Improvement of the tuner documentation (#4506)

---
 src/sagemaker/tuner.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py
index 571f84761f..967bff1b99 100644
--- a/src/sagemaker/tuner.py
+++ b/src/sagemaker/tuner.py
@@ -11,6 +11,7 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
 """Placeholder docstring"""
+
 from __future__ import absolute_import
 
 import importlib
@@ -641,8 +642,11 @@ def __init__(
                 extract the metric from the logs. This should be defined only
                 for hyperparameter tuning jobs that don't use an Amazon
                 algorithm.
-            strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations
-                (default: 'Bayesian').
+            strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations.
+                More information about different strategies:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html.
+                Available options are: 'Bayesian', 'Random', 'Hyperband',
+                'Grid' (default: 'Bayesian')
             objective_type (str or PipelineVariable): The type of the objective metric for
                 evaluating training jobs. This value can be either 'Minimize' or
                 'Maximize' (default: 'Maximize').
@@ -759,7 +763,8 @@ def __init__(
         self.autotune = autotune
 
     def override_resource_config(
-        self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
+        self,
+        instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]],
     ):
         """Override the instance configuration of the estimators used by the tuner.
 
@@ -966,7 +971,7 @@ def fit(
         include_cls_metadata: Union[bool, Dict[str, bool]] = False,
         estimator_kwargs: Optional[Dict[str, dict]] = None,
         wait: bool = True,
-        **kwargs
+        **kwargs,
     ):
         """Start a hyperparameter tuning job.
 
@@ -1055,7 +1060,7 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim
             allowed_keys=estimator_names,
         )
 
-        for (estimator_name, estimator) in self.estimator_dict.items():
+        for estimator_name, estimator in self.estimator_dict.items():
             ins = inputs.get(estimator_name, None) if inputs is not None else None
             args = estimator_kwargs.get(estimator_name, {}) if estimator_kwargs is not None else {}
             self._prepare_estimator_for_tuning(estimator, ins, job_name, **args)
@@ -1282,7 +1287,7 @@ def _attach_with_training_details_list(cls, sagemaker_session, estimator_cls, jo
             objective_metric_name_dict=objective_metric_name_dict,
             hyperparameter_ranges_dict=hyperparameter_ranges_dict,
             metric_definitions_dict=metric_definitions_dict,
-            **init_params
+            **init_params,
         )
 
     def deploy(
@@ -1297,7 +1302,7 @@ def deploy(
         model_name=None,
         kms_key=None,
         data_capture_config=None,
-        **kwargs
+        **kwargs,
     ):
         """Deploy the best trained or user specified model to an Amazon SageMaker endpoint.
 
@@ -1363,7 +1368,7 @@ def deploy(
             model_name=model_name,
             kms_key=kms_key,
             data_capture_config=data_capture_config,
-            **kwargs
+            **kwargs,
         )
 
     def stop_tuning_job(self):
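
For context, a minimal sketch of how the strategy option documented above is passed to HyperparameterTuner. The estimator settings, objective metric name, metric regex, hyperparameter range, and S3 paths below are illustrative placeholders, not values taken from this change:

    from sagemaker.estimator import Estimator
    from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

    # Illustrative estimator; the image URI, role, and bucket are placeholders.
    estimator = Estimator(
        image_uri="<training-image-uri>",
        role="<execution-role-arn>",
        instance_count=1,
        instance_type="ml.m5.xlarge",
    )

    # strategy accepts 'Bayesian' (the default), 'Random', 'Hyperband',
    # or 'Grid', per the updated docstring.
    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name="validation:accuracy",
        hyperparameter_ranges={"learning_rate": ContinuousParameter(1e-4, 1e-1)},
        metric_definitions=[
            {"Name": "validation:accuracy", "Regex": "accuracy=([0-9\\.]+)"}
        ],
        strategy="Random",
        max_jobs=10,
        max_parallel_jobs=2,
    )
    tuner.fit({"train": "s3://<bucket>/train"})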
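
Similarly, a sketch of override_resource_config, whose signature this patch only reflows. It reuses the tuner object from the sketch above and assumes InstanceConfig (defined in sagemaker.tuner) accepts instance_count, instance_type, and volume_size:

    from sagemaker.tuner import InstanceConfig

    # Offer two per-training-job resource configurations in place of the
    # estimator's fixed instance settings.
    tuner.override_resource_config(
        instance_configs=[
            InstanceConfig(instance_count=1, instance_type="ml.m5.xlarge", volume_size=30),
            InstanceConfig(instance_count=1, instance_type="ml.m5.2xlarge", volume_size=30),
        ]
    )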