Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sagemaker session region not being used #4469

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Callable, Dict, Optional, Set
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
)
Expand All @@ -24,6 +23,7 @@
)
from sagemaker.jumpstart.utils import (
get_jumpstart_gated_content_bucket,
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -72,8 +72,9 @@ def _retrieve_default_environment_variables(
dict: the inference environment variables to use for the model.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -198,8 +199,9 @@ def _retrieve_gated_model_uri_env_var_value(
ValueError: If the model specs specified are invalid.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from typing import Dict, Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
VariableScope,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -70,8 +70,9 @@ def _retrieve_default_hyperparameters(
dict: the hyperparameters to use for the model.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
ModelFramework,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -104,8 +104,9 @@ def _retrieve_image_uri(
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from typing import Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -58,8 +58,9 @@ def _model_supports_incremental_training(
bool: the support status for incremental training.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -76,8 +76,9 @@ def _retrieve_default_instance_type(
specified region due to lack of supported computing instances.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -163,8 +164,9 @@ def _retrieve_instance_types(
specified region due to lack of supported computing instances.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
22 changes: 13 additions & 9 deletions src/sagemaker/jumpstart/artifacts/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from sagemaker.utils import volume_size_supported
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)

Expand Down Expand Up @@ -62,8 +62,9 @@ def _retrieve_model_init_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -121,8 +122,9 @@ def _retrieve_model_deploy_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -176,8 +178,9 @@ def _retrieve_estimator_init_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -233,8 +236,9 @@ def _retrieve_estimator_fit_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from typing import Dict, List, Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -62,8 +62,9 @@ def _retrieve_default_training_metric_definitions(
list: the default training metric definitions to use for the model or None.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from typing import Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.jumpstart.enums import (
Expand Down Expand Up @@ -65,8 +65,9 @@ def _retrieve_model_package_arn(
str: the model package arn to use for the model or None.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -149,8 +150,9 @@ def _retrieve_model_package_model_artifact_s3_uri(

if scope == JumpStartScriptScope.TRAINING:

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
get_jumpstart_gated_content_bucket,
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -129,8 +129,9 @@ def _retrieve_model_uri(
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -206,8 +207,9 @@ def _model_supports_training_model_uri(
bool: the support status for model uri with training.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from typing import Dict, Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -63,8 +63,9 @@ def _retrieve_example_payloads(
to the serializable payload object.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
Loading
Loading