diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 006559852c..fa5ae8900b 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -15,7 +15,6 @@ from typing import Callable, Dict, Optional, Set from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) @@ -24,6 +23,7 @@ ) from sagemaker.jumpstart.utils import ( get_jumpstart_gated_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -72,8 +72,9 @@ def _retrieve_default_environment_variables( dict: the inference environment variables to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -198,8 +199,9 @@ def _retrieve_gated_model_uri_env_var_value( ValueError: If the model specs specified are invalid. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index e9e6f613f8..d19530ecfb 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -15,13 +15,13 @@ from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, VariableScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -70,8 +70,9 @@ def _retrieve_default_hyperparameters( dict: the hyperparameters to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 6ea1ca84a1..9d19d5e069 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -17,13 +17,13 @@ from sagemaker import image_uris from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ModelFramework, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -104,8 +104,9 @@ def _retrieve_image_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 753a911422..1b3c6f4b29 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -15,12 +15,12 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -58,8 +58,9 @@ def _model_supports_incremental_training( bool: the support status for incremental training. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 608303c5e6..e7c9c5911d 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -18,13 +18,13 @@ from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -76,8 +76,9 @@ def _retrieve_default_instance_type( specified region due to lack of supported computing instances. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -163,8 +164,9 @@ def _retrieve_instance_types( specified region due to lack of supported computing instances. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index c15f686805..9cd152b0bb 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -18,13 +18,13 @@ from sagemaker.utils import volume_size_supported from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) @@ -62,8 +62,9 @@ def _retrieve_model_init_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -121,8 +122,9 @@ def _retrieve_model_deploy_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -176,8 +178,9 @@ def _retrieve_estimator_init_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -233,8 +236,9 @@ def _retrieve_estimator_fit_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index b6f6019641..57f66155c7 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -16,12 +16,12 @@ from typing import Dict, List, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -62,8 +62,9 @@ def _retrieve_default_training_metric_definitions( list: the default training metric definitions to use for the model or None. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 5c8a2488c6..aa22351771 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -15,9 +15,9 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.jumpstart.enums import ( @@ -65,8 +65,9 @@ def _retrieve_model_package_arn( str: the model package arn to use for the model or None. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -149,8 +150,9 @@ def _retrieve_model_package_model_artifact_s3_uri( if scope == JumpStartScriptScope.TRAINING: - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index c41f0a75b7..6bb2e576fc 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -18,7 +18,6 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -26,6 +25,7 @@ from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -129,8 +129,9 @@ def _retrieve_model_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -206,8 +207,9 @@ def _model_supports_training_model_uri( bool: the support status for model uri with training. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 0424145119..3359e32732 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -16,7 +16,6 @@ from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -24,6 +23,7 @@ ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -63,8 +63,9 @@ def _retrieve_example_payloads( to the serializable payload object. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index e9e0e8dfde..4f6dfe1fe3 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -20,7 +20,6 @@ CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP, DEFAULT_JUMPSTART_SAGEMAKER_SESSION, DESERIALIZER_TYPE_TO_CLASS_MAP, - JUMPSTART_DEFAULT_REGION_NAME, SERIALIZER_TYPE_TO_CLASS_MAP, ) from sagemaker.jumpstart.enums import ( @@ -29,6 +28,7 @@ JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -309,8 +309,9 @@ def _retrieve_default_content_type( str: the default content type to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -359,8 +360,9 @@ def _retrieve_default_accept_type( str: the default accept type to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -410,8 +412,9 @@ def _retrieve_supported_accept_types( list: the supported accept types to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -461,8 +464,9 @@ def _retrieve_supported_content_types( list: the supported content types to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 60af520a6e..cffd46d043 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -15,13 +15,13 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -60,8 +60,9 @@ def _retrieve_resource_name_base( str: the default resource name. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 9f01a7af77..369acac85f 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -17,13 +17,13 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -89,8 +89,9 @@ def _retrieve_default_resources( retrieve default resource requirements """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index c1b037ce61..f69732d2e0 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -17,13 +17,13 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -71,8 +71,9 @@ def _retrieve_script_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -132,8 +133,9 @@ def _model_supports_inference_script_uri( bool: the support status for script uri with inference. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 7682ab3817..fff421ab32 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -26,7 +26,6 @@ ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, - JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, MODEL_TYPE_TO_MANIFEST_MAP, @@ -62,24 +61,21 @@ class JumpStartModelsCache: for launching JumpStart models from the SageMaker SDK. """ - # fmt: off def __init__( self, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_horizon: datetime.timedelta = - JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, - max_semantic_version_cache_items: int = - JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, - semantic_version_cache_expiration_horizon: datetime.timedelta = - JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, - manifest_file_s3_key: str = - JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + s3_cache_expiration_horizon: datetime.timedelta = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + semantic_version_cache_expiration_horizon: datetime.timedelta = ( + JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON + ), + manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, - ) -> None: # fmt: on + ) -> None: """Initialize a ``JumpStartModelsCache`` instance. Args: @@ -102,7 +98,10 @@ def __init__( s3_client (Optional[boto3.client]): s3 client to use. Default: None. """ - self._region = region + self._region = region or utils.get_region_fallback( + s3_bucket_name=s3_bucket_name, s3_client=s3_client + ) + self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, @@ -165,9 +164,7 @@ def set_manifest_file_s3_key( } property_name = file_mapping.get(file_type) if not property_name: - raise ValueError( - self._file_type_error_msg(file_type, manifest_only=True) - ) + raise ValueError(self._file_type_error_msg(file_type, manifest_only=True)) if key != property_name: setattr(self, property_name, key) self.clear() @@ -180,9 +177,7 @@ def get_manifest_file_s3_key( return self._manifest_file_s3_key if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: return self._proprietary_manifest_s3_key - raise ValueError( - self._file_type_error_msg(file_type, manifest_only=True) - ) + raise ValueError(self._file_type_error_msg(file_type, manifest_only=True)) def set_s3_bucket_name(self, s3_bucket_name: str) -> None: """Set s3 bucket used for cache.""" @@ -235,7 +230,8 @@ def _model_id_retrieval_function( sm_version = utils.get_sagemaker_version() manifest = self._s3_cache.get( JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] + ) )[0].formatted_content versions_compatible_with_sagemaker = [ @@ -252,7 +248,8 @@ def _model_id_retrieval_function( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) for header in manifest.values() # type: ignore + Version(header.version) + for header in manifest.values() # type: ignore if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( @@ -282,9 +279,7 @@ def _model_id_retrieval_function( raise KeyError(error_msg) error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. " - error_msg += ( - f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " - ) + error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " other_model_id_version = None if model_type == JumpStartModelType.OPEN_WEIGHTS: @@ -293,19 +288,17 @@ def _model_id_retrieval_function( ) # all versions here are incompatible with sagemaker elif model_type == JumpStartModelType.PROPRIETARY: all_possible_model_id_version = [ - header.version for header in manifest.values() # type: ignore + header.version + for header in manifest.values() # type: ignore if header.model_id == model_id ] other_model_id_version = ( - None - if not all_possible_model_id_version - else all_possible_model_id_version[0] + None if not all_possible_model_id_version else all_possible_model_id_version[0] ) if other_model_id_version is not None: error_msg += ( - f"Consider using model ID '{model_id}' with version " - f"'{other_model_id_version}'." + f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'." ) else: possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore @@ -347,15 +340,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], def _is_local_metadata_mode(self) -> bool: """Returns True if the cache should use local metadata mode, based off env variables.""" - return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ - and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) - and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ - and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])) + return ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) + and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]) + ) def _get_json_file( - self, - key: str, - filetype: JumpStartS3FileType + self, key: str, filetype: JumpStartS3FileType ) -> Tuple[Union[dict, list], Optional[str]]: """Returns json file either from s3 or local file system. @@ -379,21 +372,19 @@ def _get_json_md5_hash(self, key: str): return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"] def _get_json_file_from_local_override( - self, - key: str, - filetype: JumpStartS3FileType + self, key: str, filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - metadata_local_root = ( - os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] - ) + metadata_local_root = os.environ[ + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE + ] elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") file_path = os.path.join(metadata_local_root, key) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) return data @@ -437,9 +428,7 @@ def _retrieval_function( model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) return JumpStartCachedS3ContentValue(formatted_content=model_specs) - raise ValueError( - self._file_type_error_msg(file_type) - ) + raise ValueError(self._file_type_error_msg(file_type)) def get_manifest( self, @@ -448,7 +437,8 @@ def get_manifest( """Return entire JumpStart models manifest.""" manifest_dict = self._s3_cache.get( JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] + ) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest @@ -505,16 +495,14 @@ def _select_version( except InvalidSpecifier: raise KeyError(f"Bad semantic version: {version_str}") available_versions_filtered = list(spec.filter(available_versions)) - return ( - str(max(available_versions_filtered)) if available_versions_filtered != [] else None - ) + return str(max(available_versions_filtered)) if available_versions_filtered != [] else None def _get_header_impl( self, model_id: str, semantic_version_str: str, attempt: int = 0, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -537,7 +525,8 @@ def _get_header_impl( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] + ) )[0].formatted_content try: @@ -553,7 +542,7 @@ def get_specs( self, model_id: str, version_str: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. @@ -566,16 +555,12 @@ def get_specs( header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key - ) + JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) ) if not cache_hit and "*" in version_str: JUMPSTART_LOGGER.warning( - get_wildcard_model_version_msg( - header.model_id, version_str, header.version - ) + get_wildcard_model_version_msg(header.model_id, version_str, header.version) ) return specs.formatted_content diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 4dada409f5..bac076ea4a 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -508,7 +508,7 @@ def _validate_model_id_and_get_type_hook(): return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, - region=region, + region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 86630fcfb8..875ec9d003 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -192,8 +192,8 @@ def get_init_kwargs( estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) - estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs) @@ -393,7 +393,9 @@ def get_deploy_kwargs( def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: """Sets region in kwargs based on default or override, returns full kwargs.""" - kwargs.region = kwargs.region or JUMPSTART_DEFAULT_REGION_NAME + kwargs.region = ( + kwargs.region or kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME + ) return kwargs @@ -507,6 +509,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + region=kwargs.region, instance_type=kwargs.instance_type, ) @@ -553,6 +556,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 63b4898877..28746990e3 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -129,7 +129,9 @@ def get_default_predictor( def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets region kwargs based on default or override, returns full kwargs.""" - kwargs.region = kwargs.region or JUMPSTART_DEFAULT_REGION_NAME + kwargs.region = ( + kwargs.region or kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME + ) return kwargs @@ -758,8 +760,8 @@ def get_init_kwargs( model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index a8f4c0e9fc..4529bc11b9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -285,7 +285,7 @@ def _validate_model_id_and_type(): return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, - region=region, + region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 485354e802..85a041379a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -23,7 +23,6 @@ from sagemaker.jumpstart import accessors from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, PROPRIETARY_MODEL_SPEC_PREFIX, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType @@ -38,6 +37,7 @@ from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, get_sagemaker_version, verify_model_region_and_return_specs, validate_model_id_and_get_type, @@ -156,7 +156,7 @@ def extract_model_type_filter_representation(spec_key: str) -> str: def list_jumpstart_tasks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List tasks for JumpStart, and optionally apply filters to result. @@ -168,11 +168,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) tasks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( filter=filter, region=region, sagemaker_session=sagemaker_session @@ -184,7 +187,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin def list_jumpstart_frameworks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List frameworks for JumpStart, and optionally apply filters to result. @@ -196,11 +199,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin (eg. ``"task == ic"``). If this argument is not supplied, all frameworks will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) frameworks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( filter=filter, region=region, sagemaker_session=sagemaker_session @@ -212,7 +218,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin def list_jumpstart_scripts( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List scripts for JumpStart, and optionally apply filters to result. @@ -224,10 +230,13 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or ( isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower() ): @@ -255,7 +264,7 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin def list_jumpstart_models( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, list_incomplete_models: bool = False, list_old_models: bool = False, list_versions: bool = False, @@ -270,7 +279,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). list_incomplete_models (bool): Optional. If a model does not contain metadata fields requested by the filter, and the filter cannot be resolved to a include/not include, whether the model should be included. By default, these models are omitted from results. @@ -283,6 +292,9 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_id_version_dict: Dict[str, List[str]] = dict() for model_id, version in _generate_jumpstart_model_versions( filter=filter, @@ -312,7 +324,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, list_incomplete_models: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Generator: @@ -325,7 +337,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all models will be generated. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). list_incomplete_models (bool): Optional. If a model does not contain metadata fields requested by the filter, and the filter cannot be resolved to a include/not include, whether the model should be included. By default, these models are omitted from @@ -334,6 +346,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) + prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( region=region, s3_client=sagemaker_session.s3_client, @@ -484,7 +500,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, def get_model_url( model_id: str, model_version: str, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieve web url describing pretrained model. @@ -493,7 +509,7 @@ def get_model_url( model_id (str): The model ID for which to retrieve the url. model_version (str): The model version for which to retrieve the url. region (str): Optional. The region from which to retrieve metadata. - (Default: JUMPSTART_DEFAULT_REGION_NAME) + (Default: None) sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to retrieve the model url. """ @@ -504,6 +520,9 @@ def get_model_url( sagemaker_session=sagemaker_session, ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 242118c56e..595f801598 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -22,12 +22,12 @@ from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, ) from sagemaker.session import Session @@ -125,12 +125,14 @@ class PayloadSerializer: def __init__( self, bucket: Optional[str] = None, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, s3_client: Optional[boto3.client] = None, ) -> None: """Initializes PayloadSerializer object.""" self.bucket = bucket or get_jumpstart_content_bucket() - self.region = region + self.region = region or get_region_fallback( + s3_client=s3_client, + ) self.s3_client = s3_client def get_bytes_payload_with_s3_references( diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 62ccba7900..5f51173b24 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -540,7 +540,7 @@ def verify_model_region_and_return_specs( model_id: Optional[str], version: Optional[str], scope: Optional[str], - region: str, + region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -576,6 +576,10 @@ def verify_model_region_and_return_specs( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) + if scope is None: raise ValueError( "Must specify `model_scope` argument to retrieve model " @@ -842,3 +846,42 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version + + +def get_region_fallback( + s3_bucket_name: Optional[str] = None, + s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Returns region to use for JumpStart functionality implicitly via session objects.""" + regions_in_s3_bucket_name: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if s3_bucket_name is not None + if region in s3_bucket_name + } + regions_in_s3_client_endpoint_url: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if s3_client is not None + if region in s3_client._endpoint.host + } + + regions_in_sagemaker_session: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if sagemaker_session + if region == sagemaker_session.boto_region_name + } + + combined_regions = regions_in_s3_client_endpoint_url.union( + regions_in_s3_bucket_name, regions_in_sagemaker_session + ) + + if len(combined_regions) > 1: + raise ValueError("Unable to resolve a region name from the s3 bucket and client provided.") + + if len(combined_regions) == 0: + return constants.JUMPSTART_DEFAULT_REGION_NAME + + return list(combined_regions)[0] diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 3199e5fc2e..c7098a1185 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from typing import Any, Dict, List, Optional from sagemaker import session -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.enums import ( HyperparameterValidationMode, @@ -24,7 +23,7 @@ ) from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter -from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.utils import get_region_fallback, verify_model_region_and_return_specs def _validate_hyperparameter( @@ -168,7 +167,7 @@ def validate_hyperparameters( model_version: str, hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, - region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -184,8 +183,7 @@ def validate_hyperparameters( to this function will be validated, the missing hyperparameters will be ignored. If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated. If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. - region (str): Region for which to validate hyperparameters. (Default: JumpStart - default region). + region (str): Region for which to validate hyperparameters. (Default: None). sagemaker_session (Optional[Session]): Custom SageMaker Session to use. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -202,11 +200,15 @@ def validate_hyperparameters( """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if validation_mode is None: validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 49c18beec2..11165a0625 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -23,7 +23,8 @@ mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 7765d6eaad..d116c8121b 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -22,7 +22,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 5328533da5..f0102068e7 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -22,9 +22,10 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +region = "us-west-2" mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index cc1aad8a44..5f00f93abf 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -25,7 +25,8 @@ mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index a13fba87ae..40ee4978cf 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -24,7 +24,8 @@ mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 7a5df4ac93..07418f8ddb 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -23,8 +23,9 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +region = "us-west-2" mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 6c80c97f33..88b95b9403 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -37,8 +37,9 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.side_effect = get_spec_from_base_spec patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) image_uris.retrieve( framework=None, diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 982c7f1702..2e51afd3f7 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -34,7 +34,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_training_instance_types = instance_types.retrieve_default( region=region, @@ -178,7 +178,8 @@ def test_jumpstart_inference_instance_type_variants(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "inference-instance-types-variant-model", "*" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index fe4b122c4a..1e048ef0dd 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -16,6 +16,7 @@ from unittest import mock import unittest from inspect import signature +from mock import Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -30,12 +31,16 @@ from sagemaker.jumpstart.artifacts.metric_definitions import ( _retrieve_default_training_metric_definitions, ) -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.estimator import JumpStartEstimator from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from tests.integ.sagemaker.jumpstart.utils import get_training_dataset_for_model_and_version from sagemaker.model import Model @@ -44,6 +49,7 @@ get_special_model_spec, overwrite_dictionary, ) +import boto3 execution_role = "fake role! do not use!" @@ -1553,7 +1559,8 @@ def test_training_passes_session_to_deploy( mock_get_model_specs.side_effect = get_special_model_spec mock_role = f"dsfsdfsd{time.time()}" - mock_sagemaker_session = mock.MagicMock(sagemaker_config={}) + region = "us-west-2" + mock_sagemaker_session = mock.MagicMock(sagemaker_config={}, boto_region_name=region) mock_sagemaker_session.get_caller_identity_arn = lambda: mock_role estimator = JumpStartEstimator( @@ -1758,6 +1765,56 @@ def test_model_artifact_variant_estimator( ], ) + @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_estimator_session( + self, + mock_get_model_specs: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_deploy, + mock_fit, + mock_init, + get_default_predictor, + ): + + mock_validate_model_id_and_get_type.return_value = True + + model_id, _ = "js-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + region = "eu-west-1" # some non-default region + + if region == JUMPSTART_DEFAULT_REGION_NAME: + region = "us-west-2" + + session = Session(boto_session=boto3.session.Session(region_name=region)) + + assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME + + session.get_caller_identity_arn = Mock(return_value="blah") + + estimator = JumpStartEstimator(model_id=model_id, sagemaker_session=session) + estimator.fit() + + estimator.deploy() + + assert len(mock_get_model_specs.call_args_list) > 1 + + regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list} + + assert len(regions) == 1 + assert list(regions)[0] == region + + s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list} + assert len(s3_clients) == 1 + assert list(s3_clients)[0] == session.s3_client + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index ba4ba0bb13..c4a96d4120 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,18 +15,22 @@ from typing import Optional, Set from unittest import mock import unittest -from mock import MagicMock +from mock import MagicMock, Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.jumpstart.artifacts.environment_variables import ( _retrieve_default_environment_variables, ) -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model from sagemaker.predictor import Predictor +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from sagemaker.enums import EndpointType from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements @@ -38,6 +42,7 @@ get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, ) +import boto3 execution_role = "fake role! do not use!" region = "us-west-2" @@ -950,7 +955,7 @@ def test_jumpstart_model_tags( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -987,7 +992,9 @@ def test_jumpstart_model_tags_disabled( mock_get_model_specs.side_effect = get_special_model_spec settings = SessionSettings(include_jumpstart_tags=False) - mock_session = MagicMock(sagemaker_config={}, settings=settings) + mock_session = MagicMock( + sagemaker_config={}, settings=settings, boto_region_name="us-west-2" + ) model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -1018,7 +1025,7 @@ def test_jumpstart_model_package_arn( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -1053,7 +1060,7 @@ def test_jumpstart_model_package_arn_override( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model_package_arn = ( "arn:aws:sagemaker:us-west-2:867530986753:model-package/" @@ -1312,6 +1319,52 @@ def test_model_registry_accept_and_response_types( response_types=["application/json;verbose", "application/json"], ) + @mock.patch("sagemaker.jumpstart.model.get_default_predictor") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_model_session( + self, + mock_get_model_specs: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_deploy, + mock_init, + get_default_predictor, + ): + + mock_validate_model_id_and_get_type.return_value = True + + model_id, _ = "model_data_s3_prefix_model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + region = "eu-west-1" # some non-default region + + if region == JUMPSTART_DEFAULT_REGION_NAME: + region = "us-west-2" + + session = Session(boto_session=boto3.session.Session(region_name=region)) + + assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME + + session.get_caller_identity_arn = Mock(return_value="blah") + + model = JumpStartModel(model_id=model_id, sagemaker_session=session) + model.deploy() + + assert len(mock_get_model_specs.call_args_list) > 1 + + regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list} + + assert len(regions) == 1 + assert list(regions)[0] == region + + s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list} + assert len(s3_clients) == 1 + assert list(s3_clients)[0] == session.s3_client + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 21112926a5..3d9b5cef6a 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -329,7 +329,8 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs): class RetrieveModelPackageArnTest(unittest.TestCase): - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -439,7 +440,8 @@ def test_retrieve_model_package_arn( class PrivateJumpStartBucketTest(unittest.TestCase): - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index cb54722d48..c81d5639e5 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -15,7 +15,9 @@ from unittest import TestCase from mock.mock import Mock, patch import pytest +import boto3 import random +from sagemaker import session from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -1450,3 +1452,50 @@ def test_logger_disabled(self, mocked_emit: Mock): JUMPSTART_LOGGER.warning("Self destruct in 3...2...1...") mocked_emit.assert_not_called() + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client, sagemaker_session, region", + [ + ( + "jumpstart-cache-prod", + boto3.client("s3", region_name="blah-blah"), + session.Session(boto3.Session(region_name="blah-blah")), + JUMPSTART_DEFAULT_REGION_NAME, + ), + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-west-2"), + session.Session(boto3.Session(region_name="us-west-2")), + "us-west-2", + ), + ("jumpstart-cache-prod", boto3.client("s3", region_name="us-east-2"), None, "us-east-2"), + ], +) +def test_get_region_fallback_success(s3_bucket_name, s3_client, sagemaker_session, region): + assert region == utils.get_region_fallback(s3_bucket_name, s3_client, sagemaker_session) + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client, sagemaker_session", + [ + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-east-2"), + session.Session(boto3.Session(region_name="us-west-2")), + ), + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-west-2"), + session.Session(boto3.Session(region_name="eu-north-1")), + ), + ( + "jumpstart-cache-prod-us-west-2-us-east-2", + boto3.client("s3", region_name="us-east-2"), + None, + ), + ], +) +def test_get_region_fallback_failure(s3_bucket_name, s3_client, sagemaker_session): + with pytest.raises(ValueError): + utils.get_region_fallback(s3_bucket_name, s3_client, sagemaker_session) diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index 608a32a005..835a09a58c 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -23,7 +23,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @@ -36,7 +37,8 @@ def test_jumpstart_default_metric_definitions( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id = "pytorch-ic-mobilenet-v2" region = "us-west-2" diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 8d75731b06..8ec9478d8a 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -39,7 +39,8 @@ def test_jumpstart_common_model_uri( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_uris.retrieve( model_scope="training", diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 7b5e7a598d..1c0cfa35b3 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -37,7 +37,7 @@ def test_jumpstart_resource_requirements( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "huggingface-llm-mistral-7b-instruct", "*" default_inference_resource_requirements = resource_requirements.retrieve_default( @@ -121,7 +121,7 @@ def test_jumpstart_no_supported_resource_requirements( model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_inference_resource_requirements = resource_requirements.retrieve_default( region=region, diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index c797ba3559..16b7256ed2 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -39,7 +39,8 @@ def test_jumpstart_common_script_uri( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) script_uris.retrieve( script_scope="training", diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index c2253726bf..90ec5df6b5 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -40,7 +40,7 @@ def test_jumpstart_default_serializers( model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_serializer = serializers.retrieve_default( region=region, @@ -75,7 +75,8 @@ def test_jumpstart_serializer_options( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "predictor-specs-model", "*" region = "us-west-2"