Skip to content

Commit

Permalink
fix: sagemaker session region not being used (#4469)
Browse files Browse the repository at this point in the history
* fix: sagemaker session region not being used

* chore: add unit tests

* fix: remove all JUMPSTART_DEFAULT_REGION_NAME default arguments

* chore: use get_region_fallback throughout

* chore: remove unnecessary if statement

* chore: remove unnecessary if statement (2)

---------

Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com>
  • Loading branch information
evakravi and benieric authored Mar 13, 2024
1 parent 064378d commit 377be87
Show file tree
Hide file tree
Showing 40 changed files with 437 additions and 181 deletions.
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Callable, Dict, Optional, Set
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
)
Expand All @@ -24,6 +23,7 @@
)
from sagemaker.jumpstart.utils import (
get_jumpstart_gated_content_bucket,
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -72,8 +72,9 @@ def _retrieve_default_environment_variables(
dict: the inference environment variables to use for the model.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -198,8 +199,9 @@ def _retrieve_gated_model_uri_env_var_value(
ValueError: If the model specs specified are invalid.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from typing import Dict, Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
VariableScope,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -70,8 +70,9 @@ def _retrieve_default_hyperparameters(
dict: the hyperparameters to use for the model.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
ModelFramework,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -104,8 +104,9 @@ def _retrieve_image_uri(
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from typing import Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -58,8 +58,9 @@ def _model_supports_incremental_training(
bool: the support status for incremental training.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -76,8 +76,9 @@ def _retrieve_default_instance_type(
specified region due to lack of supported computing instances.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -163,8 +164,9 @@ def _retrieve_instance_types(
specified region due to lack of supported computing instances.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
22 changes: 13 additions & 9 deletions src/sagemaker/jumpstart/artifacts/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from sagemaker.utils import volume_size_supported
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)

Expand Down Expand Up @@ -62,8 +62,9 @@ def _retrieve_model_init_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -121,8 +122,9 @@ def _retrieve_model_deploy_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -176,8 +178,9 @@ def _retrieve_estimator_init_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -233,8 +236,9 @@ def _retrieve_estimator_fit_kwargs(
dict: the kwargs to use for the use case.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from typing import Dict, List, Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -62,8 +62,9 @@ def _retrieve_default_training_metric_definitions(
list: the default training metric definitions to use for the model or None.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from typing import Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.jumpstart.enums import (
Expand Down Expand Up @@ -65,8 +65,9 @@ def _retrieve_model_package_arn(
str: the model package arn to use for the model or None.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -149,8 +150,9 @@ def _retrieve_model_package_model_artifact_s3_uri(

if scope == JumpStartScriptScope.TRAINING:

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
12 changes: 7 additions & 5 deletions src/sagemaker/jumpstart/artifacts/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
get_jumpstart_gated_content_bucket,
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -129,8 +129,9 @@ def _retrieve_model_uri(
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down Expand Up @@ -206,8 +207,9 @@ def _model_supports_training_model_uri(
bool: the support status for model uri with training.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from typing import Dict, Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import (
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -63,8 +63,9 @@ def _retrieve_example_payloads(
to the serializable payload object.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME
region = region or get_region_fallback(
sagemaker_session=sagemaker_session,
)

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
Expand Down
Loading

0 comments on commit 377be87

Please sign in to comment.