From 0fb849f11bf4b2ffd149725a29ce7ae900339b48 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 1 Mar 2024 18:01:22 +0000 Subject: [PATCH 1/6] fix: sagemaker session region not being used --- .../artifacts/environment_variables.py | 5 +- .../jumpstart/artifacts/hyperparameters.py | 3 +- .../jumpstart/artifacts/image_uris.py | 3 +- .../artifacts/incremental_training.py | 3 +- .../jumpstart/artifacts/instance_types.py | 5 +- src/sagemaker/jumpstart/artifacts/kwargs.py | 9 +- .../jumpstart/artifacts/metric_definitions.py | 3 +- .../jumpstart/artifacts/model_packages.py | 5 +- .../jumpstart/artifacts/model_uris.py | 5 +- src/sagemaker/jumpstart/artifacts/payloads.py | 3 +- .../jumpstart/artifacts/predictors.py | 9 +- .../jumpstart/artifacts/resource_names.py | 3 +- .../artifacts/resource_requirements.py | 3 +- .../jumpstart/artifacts/script_uris.py | 5 +- src/sagemaker/jumpstart/cache.py | 106 +++++++++++------- src/sagemaker/jumpstart/estimator.py | 2 +- src/sagemaker/jumpstart/factory/estimator.py | 8 +- src/sagemaker/jumpstart/factory/model.py | 6 +- src/sagemaker/jumpstart/model.py | 2 +- src/sagemaker/jumpstart/utils.py | 5 +- src/sagemaker/jumpstart/validators.py | 2 +- .../jumpstart/test_accept_types.py | 3 +- .../jumpstart/test_content_types.py | 3 +- .../jumpstart/test_deserializers.py | 3 +- .../jumpstart/test_default.py | 3 +- .../hyperparameters/jumpstart/test_default.py | 3 +- .../jumpstart/test_validate.py | 3 +- .../image_uris/jumpstart/test_common.py | 3 +- .../jumpstart/test_instance_types.py | 5 +- .../jumpstart/estimator/test_estimator.py | 3 +- .../sagemaker/jumpstart/model/test_model.py | 10 +- .../sagemaker/jumpstart/test_artifacts.py | 6 +- .../jumpstart/test_default.py | 6 +- .../model_uris/jumpstart/test_common.py | 3 +- .../jumpstart/test_resource_requirements.py | 4 +- .../script_uris/jumpstart/test_common.py | 3 +- .../serializers/jumpstart/test_serializers.py | 5 +- 37 files changed, 148 insertions(+), 113 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 0e666e4c14..f664a00fcb 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -15,7 +15,6 @@ from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) from sagemaker.jumpstart.enums import ( @@ -72,7 +71,7 @@ def _retrieve_default_environment_variables( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -169,7 +168,7 @@ def _retrieve_gated_model_uri_env_var_value( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index e9e6f613f8..d9112f715f 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -15,7 +15,6 @@ from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -71,7 +70,7 @@ def _retrieve_default_hyperparameters( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 6ea1ca84a1..9f2c300b5e 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -17,7 +17,6 @@ from sagemaker import image_uris from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -105,7 +104,7 @@ def _retrieve_image_uri( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 753a911422..26c6076e1d 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -15,7 +15,6 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -59,7 +58,7 @@ def _model_supports_incremental_training( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 38e02e3ebd..bc52041f3b 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -18,7 +18,6 @@ from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -75,7 +74,7 @@ def _retrieve_default_instance_type( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -161,7 +160,7 @@ def _retrieve_instance_types( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 7acad9b793..44bb195466 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -18,7 +18,6 @@ from sagemaker.utils import volume_size_supported from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -61,7 +60,7 @@ def _retrieve_model_init_kwargs( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -118,7 +117,7 @@ def _retrieve_model_deploy_kwargs( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -172,7 +171,7 @@ def _retrieve_estimator_init_kwargs( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -229,7 +228,7 @@ def _retrieve_estimator_fit_kwargs( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index b6f6019641..f39f40cf33 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -16,7 +16,6 @@ from typing import Dict, List, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -63,7 +62,7 @@ def _retrieve_default_training_metric_definitions( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index bd0ae365d9..a518fb1393 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -15,7 +15,6 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -64,7 +63,7 @@ def _retrieve_model_package_arn( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -147,7 +146,7 @@ def _retrieve_model_package_model_artifact_s3_uri( if scope == JumpStartScriptScope.TRAINING: if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index c41f0a75b7..6cf26b0baa 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -18,7 +18,6 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -130,7 +129,7 @@ def _retrieve_model_uri( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -207,7 +206,7 @@ def _model_supports_training_model_uri( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3ea2c16f80..412e6f4445 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -16,7 +16,6 @@ from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -62,7 +61,7 @@ def _retrieve_example_payloads( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 8d599c89cc..f741ca92cc 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -20,7 +20,6 @@ CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP, DEFAULT_JUMPSTART_SAGEMAKER_SESSION, DESERIALIZER_TYPE_TO_CLASS_MAP, - JUMPSTART_DEFAULT_REGION_NAME, SERIALIZER_TYPE_TO_CLASS_MAP, ) from sagemaker.jumpstart.enums import ( @@ -302,7 +301,7 @@ def _retrieve_default_content_type( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -350,7 +349,7 @@ def _retrieve_default_accept_type( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -399,7 +398,7 @@ def _retrieve_supported_accept_types( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -448,7 +447,7 @@ def _retrieve_supported_content_types( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 6b05f07b15..b088c576dd 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -15,7 +15,6 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -59,7 +58,7 @@ def _retrieve_resource_name_base( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 8356d1efac..a6fe884d1b 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -17,7 +17,6 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -71,7 +70,7 @@ def _retrieve_default_resources( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index c1b037ce61..7db4cd016d 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -17,7 +17,6 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -72,7 +71,7 @@ def _retrieve_script_uri( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -133,7 +132,7 @@ def _model_supports_inference_script_uri( """ if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e26d588167..a5ad3aaf4f 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,7 +15,7 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Set, Tuple, Union import json import boto3 import botocore @@ -27,6 +27,7 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, + JUMPSTART_REGION_NAME_SET, MODEL_ID_LIST_WEB_URL, ) from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg @@ -55,23 +56,22 @@ class JumpStartModelsCache: for launching JumpStart models from the SageMaker SDK. """ - # fmt: off def __init__( self, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_horizon: datetime.timedelta = - JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, - max_semantic_version_cache_items: int = - JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, - semantic_version_cache_expiration_horizon: datetime.timedelta = - JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, - manifest_file_s3_key: str = - JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + s3_cache_expiration_horizon: datetime.timedelta = ( + JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON + ), + max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + semantic_version_cache_expiration_horizon: datetime.timedelta = ( + JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON + ), + manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, - ) -> None: # fmt: on + ) -> None: """Initialize a ``JumpStartModelsCache`` instance. Args: @@ -94,7 +94,10 @@ def __init__( s3_client (Optional[boto3.client]): s3 client to use. Default: None. """ - self._region = region + self._region = region or self._get_region_fallback( + s3_bucket_name=s3_bucket_name, s3_client=s3_client + ) + self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, @@ -119,6 +122,35 @@ def __init__( else boto3.client("s3", region_name=self._region) ) + def _get_region_fallback( + self, s3_bucket_name: Optional[str], s3_client: Optional[boto3.client] + ) -> str: + """Returns region to use throughout cache in the absence of one specified in constructor.""" + regions_in_s3_bucket_name: Set[str] = { + region + for region in JUMPSTART_REGION_NAME_SET + if s3_bucket_name is not None + if region in s3_bucket_name + } + regions_in_s3_client_endpoint_url: Set[str] = { + region + for region in JUMPSTART_REGION_NAME_SET + if s3_client is not None + if region in s3_client._endpoint.host + } + + combined_regions = regions_in_s3_client_endpoint_url.union(regions_in_s3_bucket_name) + + if len(combined_regions) > 1: + raise ValueError( + "Unable to resolve a region name from the s3 bucket and client provided." + ) + + if len(combined_regions) == 0: + return JUMPSTART_DEFAULT_REGION_NAME + + return list(combined_regions)[0] + def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" if region != self._region: @@ -192,7 +224,8 @@ def _get_manifest_key_from_model_id_semantic_version( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) for header in manifest.values() # type: ignore + Version(header.version) + for header in manifest.values() # type: ignore if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( @@ -222,17 +255,14 @@ def _get_manifest_key_from_model_id_semantic_version( raise KeyError(error_msg) error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. " - error_msg += ( - f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " - ) + error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " other_model_id_version = self._select_version( "*", versions_incompatible_with_sagemaker ) # all versions here are incompatible with sagemaker if other_model_id_version is not None: error_msg += ( - f"Consider using model ID '{model_id}' with version " - f"'{other_model_id_version}'." + f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'." ) else: @@ -249,15 +279,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], def _is_local_metadata_mode(self) -> bool: """Returns True if the cache should use local metadata mode, based off env variables.""" - return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ - and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) - and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ - and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])) + return ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) + and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]) + ) def _get_json_file( - self, - key: str, - filetype: JumpStartS3FileType + self, key: str, filetype: JumpStartS3FileType ) -> Tuple[Union[dict, list], Optional[str]]: """Returns json file either from s3 or local file system. @@ -281,21 +311,19 @@ def _get_json_md5_hash(self, key: str): return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"] def _get_json_file_from_local_override( - self, - key: str, - filetype: JumpStartS3FileType + self, key: str, filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" if filetype == JumpStartS3FileType.MANIFEST: - metadata_local_root = ( - os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] - ) + metadata_local_root = os.environ[ + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE + ] elif filetype == JumpStartS3FileType.SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") file_path = os.path.join(metadata_local_root, key) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) return data @@ -333,9 +361,7 @@ def _retrieval_function( formatted_body, _ = self._get_json_file(s3_key, file_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue( - formatted_content=model_specs - ) + return JumpStartCachedS3ContentValue(formatted_content=model_specs) raise ValueError( f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" ) @@ -382,9 +408,7 @@ def _select_version( except InvalidSpecifier: raise KeyError(f"Bad semantic version: {semantic_version_str}") available_versions_filtered = list(spec.filter(available_versions)) - return ( - str(max(available_versions_filtered)) if available_versions_filtered != [] else None - ) + return str(max(available_versions_filtered)) if available_versions_filtered != [] else None def _get_header_impl( self, @@ -436,9 +460,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS if not cache_hit and "*" in semantic_version_str: JUMPSTART_LOGGER.warning( get_wildcard_model_version_msg( - header.model_id, - semantic_version_str, - header.version + header.model_id, semantic_version_str, header.version ) ) return specs.formatted_content diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 24105c4369..bbeff4a645 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -508,7 +508,7 @@ def _is_valid_model_id_hook(): return is_valid_model_id( model_id=model_id, model_version=model_version, - region=region, + region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7ccf57983b..144978106f 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -189,8 +189,8 @@ def get_init_kwargs( estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) - estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs) @@ -390,7 +390,9 @@ def get_deploy_kwargs( def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: """Sets region in kwargs based on default or override, returns full kwargs.""" - kwargs.region = kwargs.region or JUMPSTART_DEFAULT_REGION_NAME + kwargs.region = ( + kwargs.region or kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME + ) return kwargs @@ -504,6 +506,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + region=kwargs.region, instance_type=kwargs.instance_type, ) @@ -550,6 +553,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 64e4727116..0fd512cada 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -124,7 +124,9 @@ def get_default_predictor( def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets region kwargs based on default or override, returns full kwargs.""" - kwargs.region = kwargs.region or JUMPSTART_DEFAULT_REGION_NAME + kwargs.region = ( + kwargs.region or kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME + ) return kwargs @@ -715,8 +717,8 @@ def get_init_kwargs( model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 1742f860e4..b4e1fc24d6 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -274,7 +274,7 @@ def _is_valid_model_id_hook(): return is_valid_model_id( model_id=model_id, model_version=model_version, - region=region, + region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 2621422811..66d5a3590d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -526,7 +526,7 @@ def verify_model_region_and_return_specs( model_id: Optional[str], version: Optional[str], scope: Optional[str], - region: str, + region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -561,6 +561,9 @@ def verify_model_region_and_return_specs( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ + if region is None: + region = sagemaker_session.boto_region_name + if scope is None: raise ValueError( "Must specify `model_scope` argument to retrieve model " diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 3199e5fc2e..e41f9e835c 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -206,7 +206,7 @@ def validate_hyperparameters( validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = sagemaker_session.boto_region_name model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 28211d06f1..e73c5f47aa 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -21,7 +21,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 4b2db7d7f4..194394899f 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -21,7 +21,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 9d6e2f21de..2be31ed536 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -21,9 +21,10 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +region = "us-west-2" mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index acd8d19923..d118872a24 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -22,7 +22,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index eebc079164..12fbab48fa 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -23,7 +23,8 @@ mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0054ed9dbd..c8f553ad2e 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -23,8 +23,9 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +region = "us-west-2" mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 8a41891280..8862bf092e 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -32,8 +32,9 @@ def test_jumpstart_common_image_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) image_uris.retrieve( framework=None, diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index bed2e50674..5fded4c03b 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -31,7 +31,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_training_instance_types = instance_types.retrieve_default( region=region, @@ -168,7 +168,8 @@ def test_jumpstart_inference_instance_type_variants(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "inference-instance-types-variant-model", "*" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4dc35b65ca..bc0c9a7cec 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1567,7 +1567,8 @@ def test_training_passes_session_to_deploy( mock_get_model_specs.side_effect = get_special_model_spec mock_role = f"dsfsdfsd{time.time()}" - mock_sagemaker_session = mock.MagicMock(sagemaker_config={}) + region = "us-west-2" + mock_sagemaker_session = mock.MagicMock(sagemaker_config={}, boto_region_name=region) mock_sagemaker_session.get_caller_identity_arn = lambda: mock_role estimator = JumpStartEstimator( diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index f45283935b..0826b76184 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -888,7 +888,7 @@ def test_jumpstart_model_tags( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -925,7 +925,9 @@ def test_jumpstart_model_tags_disabled( mock_get_model_specs.side_effect = get_special_model_spec settings = SessionSettings(include_jumpstart_tags=False) - mock_session = MagicMock(sagemaker_config={}, settings=settings) + mock_session = MagicMock( + sagemaker_config={}, settings=settings, boto_region_name="us-west-2" + ) model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -956,7 +958,7 @@ def test_jumpstart_model_package_arn( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -991,7 +993,7 @@ def test_jumpstart_model_package_arn_override( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model_package_arn = ( "arn:aws:sagemaker:us-west-2:867530986753:model-package/" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 1a770f785f..8a7b68cf51 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -329,7 +329,8 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs): class RetrieveModelPackageArnTest(unittest.TestCase): - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_retrieve_model_package_arn(self, patched_get_model_specs): @@ -435,7 +436,8 @@ def test_retrieve_model_package_arn(self, patched_get_model_specs): class PrivateJumpStartBucketTest(unittest.TestCase): - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ffc6000c91..8eeae0b3a2 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -22,7 +22,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -31,7 +32,8 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): patched_get_model_specs.side_effect = get_spec_from_base_spec mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id = "pytorch-ic-mobilenet-v2" region = "us-west-2" diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 000540e12e..603c86fa13 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -34,7 +34,8 @@ def test_jumpstart_common_model_uri( patched_get_model_specs.side_effect = get_spec_from_base_spec mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_uris.retrieve( model_scope="training", diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 28b53270f8..6cc5312f61 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -28,7 +28,7 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): patched_get_model_specs.side_effect = get_spec_from_base_spec region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "huggingface-llm-mistral-7b-instruct", "*" default_inference_resource_requirements = resource_requirements.retrieve_default( @@ -57,7 +57,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_inference_resource_requirements = resource_requirements.retrieve_default( region=region, diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 3f38326608..5abcc8ab8c 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -34,7 +34,8 @@ def test_jumpstart_common_script_uri( patched_get_model_specs.side_effect = get_spec_from_base_spec mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) script_uris.retrieve( script_scope="training", diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index b22b61dc40..5b573403cb 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -35,7 +35,7 @@ def test_jumpstart_default_serializers( model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_serializer = serializers.retrieve_default( region=region, @@ -65,7 +65,8 @@ def test_jumpstart_serializer_options( patched_get_model_specs.side_effect = get_special_model_spec mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" From 18d388c9202dc61cd877fed1b08749d3d5265510 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 4 Mar 2024 16:00:20 +0000 Subject: [PATCH 2/6] chore: add unit tests --- .../jumpstart/estimator/test_estimator.py | 58 ++++++++++++++++++- .../sagemaker/jumpstart/model/test_model.py | 55 +++++++++++++++++- tests/unit/sagemaker/jumpstart/test_cache.py | 36 ++++++++++++ 3 files changed, 146 insertions(+), 3 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index bc0c9a7cec..9318ba9c64 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -16,6 +16,7 @@ from unittest import mock import unittest from inspect import signature +from mock import Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -30,12 +31,16 @@ from sagemaker.jumpstart.artifacts.metric_definitions import ( _retrieve_default_training_metric_definitions, ) -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag from sagemaker.jumpstart.estimator import JumpStartEstimator from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from tests.integ.sagemaker.jumpstart.utils import get_training_dataset_for_model_and_version from sagemaker.model import Model @@ -44,6 +49,7 @@ get_special_model_spec, overwrite_dictionary, ) +import boto3 execution_role = "fake role! do not use!" @@ -1773,6 +1779,56 @@ def test_model_artifact_variant_estimator( ], ) + @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_estimator_session( + self, + mock_get_model_specs: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_deploy, + mock_fit, + mock_init, + get_default_predictor, + ): + + mock_is_valid_model_id.return_value = True + + model_id, _ = "js-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + region = "eu-west-1" # some non-default region + + if region == JUMPSTART_DEFAULT_REGION_NAME: + region = "us-west-2" + + session = Session(boto_session=boto3.session.Session(region_name=region)) + + assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME + + session.get_caller_identity_arn = Mock(return_value="blah") + + estimator = JumpStartEstimator(model_id=model_id, sagemaker_session=session) + estimator.fit() + + estimator.deploy() + + assert len(mock_get_model_specs.call_args_list) > 1 + + regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list} + + assert len(regions) == 1 + assert list(regions)[0] == region + + s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list} + assert len(s3_clients) == 1 + assert list(s3_clients)[0] == session.s3_client + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 0826b76184..8c9508023b 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,18 +15,22 @@ from typing import Optional, Set from unittest import mock import unittest -from mock import MagicMock +from mock import MagicMock, Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.jumpstart.artifacts.environment_variables import ( _retrieve_default_environment_variables, ) -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model from sagemaker.predictor import Predictor +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from sagemaker.enums import EndpointType from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements @@ -36,6 +40,7 @@ overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, ) +import boto3 execution_role = "fake role! do not use!" region = "us-west-2" @@ -1252,6 +1257,52 @@ def test_model_registry_accept_and_response_types( response_types=["application/json;verbose", "application/json"], ) + @mock.patch("sagemaker.jumpstart.model.get_default_predictor") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_model_session( + self, + mock_get_model_specs: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_deploy, + mock_init, + get_default_predictor, + ): + + mock_is_valid_model_id.return_value = True + + model_id, _ = "model_data_s3_prefix_model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + region = "eu-west-1" # some non-default region + + if region == JUMPSTART_DEFAULT_REGION_NAME: + region = "us-west-2" + + session = Session(boto_session=boto3.session.Session(region_name=region)) + + assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME + + session.get_caller_identity_arn = Mock(return_value="blah") + + model = JumpStartModel(model_id=model_id, sagemaker_session=session) + model.deploy() + + assert len(mock_get_model_specs.call_args_list) > 1 + + regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list} + + assert len(regions) == 1 + assert list(regions)[0] == region + + s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list} + assert len(s3_clients) == 1 + assert list(s3_clients)[0] == session.s3_client + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 6633ecdc23..47c88802fa 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -18,6 +18,7 @@ from unittest.mock import Mock, call, mock_open from botocore.stub import Stubber import botocore +import boto3 from mock.mock import MagicMock import pytest @@ -27,6 +28,7 @@ from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, + JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.types import ( JumpStartModelHeader, @@ -854,3 +856,37 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( ), ] ) + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client, region", + [ + ( + "jumpstart-cache-prod", + boto3.client("s3", region_name="blah-blah"), + JUMPSTART_DEFAULT_REGION_NAME, + ), + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-west-2"), + "us-west-2", + ), + ("jumpstart-cache-prod", boto3.client("s3", region_name="us-east-2"), "us-east-2"), + ], +) +def test_get_region_fallback_success(s3_bucket_name, s3_client, region): + cache = JumpStartModelsCache() + assert region == cache._get_region_fallback(s3_bucket_name, s3_client) + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client", + [ + ("jumpstart-cache-prod-us-west-2", boto3.client("s3", region_name="us-east-2")), + ("jumpstart-cache-prod-us-west-2-us-east-2", boto3.client("s3", region_name="us-east-2")), + ], +) +def test_get_region_fallback_failure(s3_bucket_name, s3_client): + cache = JumpStartModelsCache() + with pytest.raises(ValueError): + cache._get_region_fallback(s3_bucket_name, s3_client) From 7279028b3b8b32f6c6828cf85469e03aac069025 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 12 Mar 2024 19:03:52 +0000 Subject: [PATCH 3/6] fix: remove all JUMPSTART_DEFAULT_REGION_NAME default arguments --- src/sagemaker/jumpstart/cache.py | 35 +------------- src/sagemaker/jumpstart/notebook_utils.py | 45 ++++++++++++------ src/sagemaker/jumpstart/payload_utils.py | 8 ++-- src/sagemaker/jumpstart/utils.py | 41 +++++++++++++++- src/sagemaker/jumpstart/validators.py | 11 +++-- tests/unit/sagemaker/jumpstart/test_cache.py | 36 -------------- tests/unit/sagemaker/jumpstart/test_utils.py | 49 ++++++++++++++++++++ 7 files changed, 134 insertions(+), 91 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index a5ad3aaf4f..f870f0ca1a 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,7 +15,7 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Set, Tuple, Union +from typing import List, Optional, Tuple, Union import json import boto3 import botocore @@ -25,9 +25,7 @@ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, - JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, - JUMPSTART_REGION_NAME_SET, MODEL_ID_LIST_WEB_URL, ) from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg @@ -94,7 +92,7 @@ def __init__( s3_client (Optional[boto3.client]): s3 client to use. Default: None. """ - self._region = region or self._get_region_fallback( + self._region = region or utils.get_region_fallback( s3_bucket_name=s3_bucket_name, s3_client=s3_client ) @@ -122,35 +120,6 @@ def __init__( else boto3.client("s3", region_name=self._region) ) - def _get_region_fallback( - self, s3_bucket_name: Optional[str], s3_client: Optional[boto3.client] - ) -> str: - """Returns region to use throughout cache in the absence of one specified in constructor.""" - regions_in_s3_bucket_name: Set[str] = { - region - for region in JUMPSTART_REGION_NAME_SET - if s3_bucket_name is not None - if region in s3_bucket_name - } - regions_in_s3_client_endpoint_url: Set[str] = { - region - for region in JUMPSTART_REGION_NAME_SET - if s3_client is not None - if region in s3_client._endpoint.host - } - - combined_regions = regions_in_s3_client_endpoint_url.union(regions_in_s3_bucket_name) - - if len(combined_regions) > 1: - raise ValueError( - "Unable to resolve a region name from the s3 bucket and client provided." - ) - - if len(combined_regions) == 0: - return JUMPSTART_DEFAULT_REGION_NAME - - return list(combined_regions)[0] - def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" if region != self._region: diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 1554025995..5e61ccced4 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -23,7 +23,6 @@ from sagemaker.jumpstart import accessors from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.filters import ( @@ -36,6 +35,7 @@ from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, get_sagemaker_version, verify_model_region_and_return_specs, ) @@ -143,7 +143,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: def list_jumpstart_tasks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List tasks for JumpStart, and optionally apply filters to result. @@ -155,11 +155,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) tasks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( filter=filter, region=region, sagemaker_session=sagemaker_session @@ -171,7 +174,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin def list_jumpstart_frameworks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List frameworks for JumpStart, and optionally apply filters to result. @@ -183,11 +186,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin (eg. ``"task == ic"``). If this argument is not supplied, all frameworks will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) frameworks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( filter=filter, region=region, sagemaker_session=sagemaker_session @@ -199,7 +205,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin def list_jumpstart_scripts( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List scripts for JumpStart, and optionally apply filters to result. @@ -211,10 +217,13 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or ( isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower() ): @@ -242,7 +251,7 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin def list_jumpstart_models( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, list_incomplete_models: bool = False, list_old_models: bool = False, list_versions: bool = False, @@ -257,7 +266,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). list_incomplete_models (bool): Optional. If a model does not contain metadata fields requested by the filter, and the filter cannot be resolved to a include/not include, whether the model should be included. By default, these models are omitted from results. @@ -270,6 +279,9 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_id_version_dict: Dict[str, List[str]] = dict() for model_id, version in _generate_jumpstart_model_versions( filter=filter, @@ -299,7 +311,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, list_incomplete_models: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Generator: @@ -312,7 +324,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all models will be generated. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). list_incomplete_models (bool): Optional. If a model does not contain metadata fields requested by the filter, and the filter cannot be resolved to a include/not include, whether the model should be included. By default, these models are omitted from @@ -321,6 +333,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( region=region, s3_client=sagemaker_session.s3_client ) @@ -453,7 +469,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, def get_model_url( model_id: str, model_version: str, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieve web url describing pretrained model. @@ -462,11 +478,14 @@ def get_model_url( model_id (str): The model ID for which to retrieve the url. model_version (str): The model version for which to retrieve the url. region (str): Optional. The region from which to retrieve metadata. - (Default: JUMPSTART_DEFAULT_REGION_NAME) + (Default: None) sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to retrieve the model url. """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 242118c56e..595f801598 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -22,12 +22,12 @@ from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, ) from sagemaker.session import Session @@ -125,12 +125,14 @@ class PayloadSerializer: def __init__( self, bucket: Optional[str] = None, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, s3_client: Optional[boto3.client] = None, ) -> None: """Initializes PayloadSerializer object.""" self.bucket = bucket or get_jumpstart_content_bucket() - self.region = region + self.region = region or get_region_fallback( + s3_client=s3_client, + ) self.s3_client = s3_client def get_bytes_payload_with_s3_references( diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 66d5a3590d..27af763f5f 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version @@ -813,3 +813,42 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version + + +def get_region_fallback( + s3_bucket_name: Optional[str] = None, + s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Returns region to use for JumpStart functionality implicitly via session objects.""" + regions_in_s3_bucket_name: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if s3_bucket_name is not None + if region in s3_bucket_name + } + regions_in_s3_client_endpoint_url: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if s3_client is not None + if region in s3_client._endpoint.host + } + + regions_in_sagemaker_session: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if sagemaker_session + if region == sagemaker_session.boto_region_name + } + + combined_regions = regions_in_s3_client_endpoint_url.union( + regions_in_s3_bucket_name, regions_in_sagemaker_session + ) + + if len(combined_regions) > 1: + raise ValueError("Unable to resolve a region name from the s3 bucket and client provided.") + + if len(combined_regions) == 0: + return constants.JUMPSTART_DEFAULT_REGION_NAME + + return list(combined_regions)[0] diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index e41f9e835c..f3df507f65 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from typing import Any, Dict, List, Optional from sagemaker import session -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.enums import ( HyperparameterValidationMode, @@ -24,7 +23,7 @@ ) from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter -from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.utils import get_region_fallback, verify_model_region_and_return_specs def _validate_hyperparameter( @@ -168,7 +167,7 @@ def validate_hyperparameters( model_version: str, hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, - region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -184,8 +183,7 @@ def validate_hyperparameters( to this function will be validated, the missing hyperparameters will be ignored. If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated. If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. - region (str): Region for which to validate hyperparameters. (Default: JumpStart - default region). + region (str): Region for which to validate hyperparameters. (Default: None). sagemaker_session (Optional[Session]): Custom SageMaker Session to use. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -202,6 +200,9 @@ def validate_hyperparameters( """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if validation_mode is None: validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 47c88802fa..6633ecdc23 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -18,7 +18,6 @@ from unittest.mock import Mock, call, mock_open from botocore.stub import Stubber import botocore -import boto3 from mock.mock import MagicMock import pytest @@ -28,7 +27,6 @@ from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.types import ( JumpStartModelHeader, @@ -856,37 +854,3 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( ), ] ) - - -@pytest.mark.parametrize( - "s3_bucket_name, s3_client, region", - [ - ( - "jumpstart-cache-prod", - boto3.client("s3", region_name="blah-blah"), - JUMPSTART_DEFAULT_REGION_NAME, - ), - ( - "jumpstart-cache-prod-us-west-2", - boto3.client("s3", region_name="us-west-2"), - "us-west-2", - ), - ("jumpstart-cache-prod", boto3.client("s3", region_name="us-east-2"), "us-east-2"), - ], -) -def test_get_region_fallback_success(s3_bucket_name, s3_client, region): - cache = JumpStartModelsCache() - assert region == cache._get_region_fallback(s3_bucket_name, s3_client) - - -@pytest.mark.parametrize( - "s3_bucket_name, s3_client", - [ - ("jumpstart-cache-prod-us-west-2", boto3.client("s3", region_name="us-east-2")), - ("jumpstart-cache-prod-us-west-2-us-east-2", boto3.client("s3", region_name="us-east-2")), - ], -) -def test_get_region_fallback_failure(s3_bucket_name, s3_client): - cache = JumpStartModelsCache() - with pytest.raises(ValueError): - cache._get_region_fallback(s3_bucket_name, s3_client) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 556b99bc9c..90978350fc 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -15,7 +15,9 @@ from unittest import TestCase from mock.mock import Mock, patch import pytest +import boto3 import random +from sagemaker import session from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -1426,3 +1428,50 @@ def test_logger_disabled(self, mocked_emit: Mock): JUMPSTART_LOGGER.warning("Self destruct in 3...2...1...") mocked_emit.assert_not_called() + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client, sagemaker_session, region", + [ + ( + "jumpstart-cache-prod", + boto3.client("s3", region_name="blah-blah"), + session.Session(boto3.Session(region_name="blah-blah")), + JUMPSTART_DEFAULT_REGION_NAME, + ), + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-west-2"), + session.Session(boto3.Session(region_name="us-west-2")), + "us-west-2", + ), + ("jumpstart-cache-prod", boto3.client("s3", region_name="us-east-2"), None, "us-east-2"), + ], +) +def test_get_region_fallback_success(s3_bucket_name, s3_client, sagemaker_session, region): + assert region == utils.get_region_fallback(s3_bucket_name, s3_client, sagemaker_session) + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client, sagemaker_session", + [ + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-east-2"), + session.Session(boto3.Session(region_name="us-west-2")), + ), + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-west-2"), + session.Session(boto3.Session(region_name="eu-north-1")), + ), + ( + "jumpstart-cache-prod-us-west-2-us-east-2", + boto3.client("s3", region_name="us-east-2"), + None, + ), + ], +) +def test_get_region_fallback_failure(s3_bucket_name, s3_client, sagemaker_session): + with pytest.raises(ValueError): + utils.get_region_fallback(s3_bucket_name, s3_client, sagemaker_session) From 0e450e84adf76b458f10a87cca7750b80a4d33a7 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 12 Mar 2024 22:19:48 +0000 Subject: [PATCH 4/6] chore: use get_region_fallback throughout --- .../artifacts/environment_variables.py | 9 +++++++-- .../jumpstart/artifacts/hyperparameters.py | 5 ++++- src/sagemaker/jumpstart/artifacts/image_uris.py | 5 ++++- .../jumpstart/artifacts/incremental_training.py | 5 ++++- .../jumpstart/artifacts/instance_types.py | 9 +++++++-- src/sagemaker/jumpstart/artifacts/kwargs.py | 17 +++++++++++++---- .../jumpstart/artifacts/metric_definitions.py | 5 ++++- .../jumpstart/artifacts/model_packages.py | 9 +++++++-- src/sagemaker/jumpstart/artifacts/model_uris.py | 9 +++++++-- src/sagemaker/jumpstart/artifacts/payloads.py | 5 ++++- src/sagemaker/jumpstart/artifacts/predictors.py | 17 +++++++++++++---- .../jumpstart/artifacts/resource_names.py | 5 ++++- .../artifacts/resource_requirements.py | 5 ++++- .../jumpstart/artifacts/script_uris.py | 9 +++++++-- src/sagemaker/jumpstart/utils.py | 4 +++- src/sagemaker/jumpstart/validators.py | 4 +++- 16 files changed, 95 insertions(+), 27 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index f664a00fcb..c5b1893597 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.utils import ( get_jumpstart_gated_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -71,7 +72,9 @@ def _retrieve_default_environment_variables( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -168,7 +171,9 @@ def _retrieve_gated_model_uri_env_var_value( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index d9112f715f..e22d9d27dd 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -21,6 +21,7 @@ VariableScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -70,7 +71,9 @@ def _retrieve_default_hyperparameters( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 9f2c300b5e..d162f7f9f5 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -23,6 +23,7 @@ ModelFramework, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -104,7 +105,9 @@ def _retrieve_image_uri( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 26c6076e1d..8a70671362 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -20,6 +20,7 @@ JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -58,7 +59,9 @@ def _model_supports_incremental_training( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 30f48a2300..5f79243d7e 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -24,6 +24,7 @@ JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -76,7 +77,9 @@ def _retrieve_default_instance_type( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -163,7 +166,9 @@ def _retrieve_instance_types( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index c289625192..2246f0638a 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -24,6 +24,7 @@ JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) @@ -62,7 +63,9 @@ def _retrieve_model_init_kwargs( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -121,7 +124,9 @@ def _retrieve_model_deploy_kwargs( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -176,7 +181,9 @@ def _retrieve_estimator_init_kwargs( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -233,7 +240,9 @@ def _retrieve_estimator_fit_kwargs( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index f39f40cf33..4bbef49620 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -21,6 +21,7 @@ JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -62,7 +63,9 @@ def _retrieve_default_training_metric_definitions( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index d12b2f573f..64ed038afb 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -17,6 +17,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.jumpstart.enums import ( @@ -65,7 +66,9 @@ def _retrieve_model_package_arn( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -149,7 +152,9 @@ def _retrieve_model_package_model_artifact_s3_uri( if scope == JumpStartScriptScope.TRAINING: if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 6cf26b0baa..5a5bede37a 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -129,7 +130,9 @@ def _retrieve_model_uri( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -206,7 +209,9 @@ def _model_supports_training_model_uri( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index d97dc50bdd..6ed64de47d 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -23,6 +23,7 @@ ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -63,7 +64,9 @@ def _retrieve_example_payloads( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 1cda33b8dc..2f352999c8 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -28,6 +28,7 @@ JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -309,7 +310,9 @@ def _retrieve_default_content_type( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -359,7 +362,9 @@ def _retrieve_default_accept_type( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -410,7 +415,9 @@ def _retrieve_supported_accept_types( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -461,7 +468,9 @@ def _retrieve_supported_content_types( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 978eff4961..db392b93f4 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -21,6 +21,7 @@ JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -60,7 +61,9 @@ def _retrieve_resource_name_base( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index e28e0c1761..212c558465 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -23,6 +23,7 @@ JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -89,7 +90,9 @@ def _retrieve_default_resources( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index 7db4cd016d..cf53ac31dd 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -23,6 +23,7 @@ ) from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -71,7 +72,9 @@ def _retrieve_script_uri( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -132,7 +135,9 @@ def _model_supports_inference_script_uri( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index e5cf45e502..69e60ba35e 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -573,7 +573,9 @@ def verify_model_region_and_return_specs( """ if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if scope is None: raise ValueError( diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index f3df507f65..65932ae245 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -207,7 +207,9 @@ def validate_hyperparameters( validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED if region is None: - region = sagemaker_session.boto_region_name + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, From 6f7dde96a96284fad82ccaf1160f6e3a10707cf2 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 12 Mar 2024 23:04:28 +0000 Subject: [PATCH 5/6] chore: remove unnecessary if statement --- .../artifacts/environment_variables.py | 14 ++++------ .../jumpstart/artifacts/hyperparameters.py | 7 ++--- .../jumpstart/artifacts/image_uris.py | 7 ++--- .../artifacts/incremental_training.py | 7 ++--- .../jumpstart/artifacts/instance_types.py | 14 ++++------ src/sagemaker/jumpstart/artifacts/kwargs.py | 28 ++++++++----------- .../jumpstart/artifacts/metric_definitions.py | 7 ++--- .../jumpstart/artifacts/model_packages.py | 7 ++--- .../jumpstart/artifacts/model_uris.py | 14 ++++------ src/sagemaker/jumpstart/artifacts/payloads.py | 7 ++--- .../jumpstart/artifacts/predictors.py | 28 ++++++++----------- .../jumpstart/artifacts/resource_names.py | 7 ++--- .../artifacts/resource_requirements.py | 7 ++--- .../jumpstart/artifacts/script_uris.py | 14 ++++------ src/sagemaker/jumpstart/utils.py | 7 ++--- src/sagemaker/jumpstart/validators.py | 7 ++--- 16 files changed, 78 insertions(+), 104 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index c5b1893597..a6a8f5e7af 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -71,10 +71,9 @@ def _retrieve_default_environment_variables( dict: the inference environment variables to use for the model. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -170,10 +169,9 @@ def _retrieve_gated_model_uri_env_var_value( ValueError: If the model specs specified are invalid. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index e22d9d27dd..d19530ecfb 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -70,10 +70,9 @@ def _retrieve_default_hyperparameters( dict: the hyperparameters to use for the model. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index d162f7f9f5..9d19d5e069 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -104,10 +104,9 @@ def _retrieve_image_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 8a70671362..1b3c6f4b29 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -58,10 +58,9 @@ def _model_supports_incremental_training( bool: the support status for incremental training. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 5f79243d7e..e7c9c5911d 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -76,10 +76,9 @@ def _retrieve_default_instance_type( specified region due to lack of supported computing instances. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -165,10 +164,9 @@ def _retrieve_instance_types( specified region due to lack of supported computing instances. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 2246f0638a..9cd152b0bb 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -62,10 +62,9 @@ def _retrieve_model_init_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -123,10 +122,9 @@ def _retrieve_model_deploy_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -180,10 +178,9 @@ def _retrieve_estimator_init_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -239,10 +236,9 @@ def _retrieve_estimator_fit_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 4bbef49620..57f66155c7 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -62,10 +62,9 @@ def _retrieve_default_training_metric_definitions( list: the default training metric definitions to use for the model or None. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 64ed038afb..e0f654eadc 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -65,10 +65,9 @@ def _retrieve_model_package_arn( str: the model package arn to use for the model or None. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 5a5bede37a..6bb2e576fc 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -129,10 +129,9 @@ def _retrieve_model_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -208,10 +207,9 @@ def _model_supports_training_model_uri( bool: the support status for model uri with training. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 6ed64de47d..3359e32732 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -63,10 +63,9 @@ def _retrieve_example_payloads( to the serializable payload object. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 2f352999c8..4f6dfe1fe3 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -309,10 +309,9 @@ def _retrieve_default_content_type( str: the default content type to use for the model. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -361,10 +360,9 @@ def _retrieve_default_accept_type( str: the default accept type to use for the model. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -414,10 +412,9 @@ def _retrieve_supported_accept_types( list: the supported accept types to use for the model. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -467,10 +464,9 @@ def _retrieve_supported_content_types( list: the supported content types to use for the model. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index db392b93f4..cffd46d043 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -60,10 +60,9 @@ def _retrieve_resource_name_base( str: the default resource name. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 212c558465..369acac85f 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -89,10 +89,9 @@ def _retrieve_default_resources( retrieve default resource requirements """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index cf53ac31dd..f69732d2e0 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -71,10 +71,9 @@ def _retrieve_script_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -134,10 +133,9 @@ def _model_supports_inference_script_uri( bool: the support status for script uri with inference. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 69e60ba35e..8078e4daea 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -572,10 +572,9 @@ def verify_model_region_and_return_specs( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if scope is None: raise ValueError( diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 65932ae245..c7098a1185 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -206,10 +206,9 @@ def validate_hyperparameters( if validation_mode is None: validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, From 8576b327ad7d530d9911461bf94a0baf3d1c9c0e Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 12 Mar 2024 23:05:46 +0000 Subject: [PATCH 6/6] chore: remove unnecessary if statement (2) --- src/sagemaker/jumpstart/artifacts/model_packages.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index e0f654eadc..aa22351771 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -150,10 +150,9 @@ def _retrieve_model_package_model_artifact_s3_uri( if scope == JumpStartScriptScope.TRAINING: - if region is None: - region = region or get_region_fallback( - sagemaker_session=sagemaker_session, - ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id,