Skip to content

Commit

Permalink
Improvement of the tuner documentation (aws#4506)
Browse files Browse the repository at this point in the history
  • Loading branch information
repushko authored and jiapinw committed Jun 25, 2024
1 parent e0cccd7 commit c51907e
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""

from __future__ import absolute_import

import importlib
Expand Down Expand Up @@ -641,8 +642,11 @@ def __init__(
extract the metric from the logs. This should be defined only
for hyperparameter tuning jobs that don't use an Amazon
algorithm.
strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations
(default: 'Bayesian').
strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations.
More information about different strategies:
https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html.
Available options are: 'Bayesian', 'Random', 'Hyperband',
'Grid' (default: 'Bayesian')
objective_type (str or PipelineVariable): The type of the objective metric for
evaluating training jobs. This value can be either 'Minimize' or
'Maximize' (default: 'Maximize').
Expand Down Expand Up @@ -759,7 +763,8 @@ def __init__(
self.autotune = autotune

def override_resource_config(
self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
self,
instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]],
):
"""Override the instance configuration of the estimators used by the tuner.
Expand Down Expand Up @@ -966,7 +971,7 @@ def fit(
include_cls_metadata: Union[bool, Dict[str, bool]] = False,
estimator_kwargs: Optional[Dict[str, dict]] = None,
wait: bool = True,
**kwargs
**kwargs,
):
"""Start a hyperparameter tuning job.
Expand Down Expand Up @@ -1055,7 +1060,7 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim
allowed_keys=estimator_names,
)

for (estimator_name, estimator) in self.estimator_dict.items():
for estimator_name, estimator in self.estimator_dict.items():
ins = inputs.get(estimator_name, None) if inputs is not None else None
args = estimator_kwargs.get(estimator_name, {}) if estimator_kwargs is not None else {}
self._prepare_estimator_for_tuning(estimator, ins, job_name, **args)
Expand Down Expand Up @@ -1282,7 +1287,7 @@ def _attach_with_training_details_list(cls, sagemaker_session, estimator_cls, jo
objective_metric_name_dict=objective_metric_name_dict,
hyperparameter_ranges_dict=hyperparameter_ranges_dict,
metric_definitions_dict=metric_definitions_dict,
**init_params
**init_params,
)

def deploy(
Expand All @@ -1297,7 +1302,7 @@ def deploy(
model_name=None,
kms_key=None,
data_capture_config=None,
**kwargs
**kwargs,
):
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint.
Expand Down Expand Up @@ -1363,7 +1368,7 @@ def deploy(
model_name=model_name,
kms_key=kms_key,
data_capture_config=data_capture_config,
**kwargs
**kwargs,
)

def stop_tuning_job(self):
Expand Down

0 comments on commit c51907e

Please sign in to comment.