Skip to content

Commit

Permalink
merge with master-curated-jumpstart
Browse files Browse the repository at this point in the history
  • Loading branch information
bencrabtree committed Mar 18, 2024
1 parent 2eff8fb commit b50c557
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 104 deletions.
3 changes: 1 addition & 2 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
MODEL_TYPE_TO_MANIFEST_MAP,
MODEL_TYPE_TO_SPECS_MAP,
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.exceptions import (
get_wildcard_model_version_msg,
Expand Down Expand Up @@ -443,7 +442,7 @@ def _retrieval_function(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)

if data_type in {
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
JumpStartS3FileType.PROPRIETARY_SPECS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import JumpStartModelSpecs
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
get_jumpstart_gated_content_bucket,
)


class PublicModelDataAccessor:
Expand Down
4 changes: 1 addition & 3 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import os
import re
from typing import Any, Dict, List, Set, Optional, Tuple, Union
import re
from typing import Any, Dict, List, Set, Optional, Tuple, Union
from urllib.parse import urlparse
import boto3
from packaging.version import Version
Expand Down Expand Up @@ -876,7 +874,7 @@ def extract_info_from_hub_content_arn(
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""

match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
if match:
hub_name = match.group(4)
hub_region = match.group(2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def test_s3_path_file_generator_with_no_objects(s3_client):
s3_client.list_objects_v2.assert_called_once()
assert response == []


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client):
specs = Mock()
Expand All @@ -154,6 +155,7 @@ def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_c
),
]


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):
specs = Mock()
Expand All @@ -167,4 +169,4 @@ def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):

response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)

assert response == []
assert response == []
63 changes: 0 additions & 63 deletions tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,69 +177,6 @@ def test_create_hub_bucket_if_it_does_not_exist():
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn


def test_generate_default_hub_bucket_name():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.boto_region_name = "us-east-1"

assert (
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
== "sagemaker-hubs-us-east-1-123456789123"
)


def test_create_hub_bucket_if_it_does_not_exist():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
"Account": "123456789123"
}
hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"
# Mock custom session with custom values
mock_custom_session = Mock()
mock_custom_session.account_id.return_value = "000000000000"
mock_custom_session.boto_region_name = "us-east-2"
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
mock_sagemaker_session.boto_region_name = "us-east-1"

bucket_name = "sagemaker-hubs-us-east-1-123456789123"
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
sagemaker_session=mock_sagemaker_session
)

mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
assert created_hub_bucket_name == bucket_name
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn


def test_generate_default_hub_bucket_name():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.boto_region_name = "us-east-1"

assert (
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
== "sagemaker-hubs-us-east-1-123456789123"
)


def test_create_hub_bucket_if_it_does_not_exist():
mock_sagemaker_session = Mock()
mock_sagemaker_session.account_id.return_value = "123456789123"
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
"Account": "123456789123"
}
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
mock_sagemaker_session.boto_region_name = "us-east-1"
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
sagemaker_session=mock_sagemaker_session
)

mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
assert created_hub_bucket_name == bucket_name


def test_is_gated_bucket():
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True

Expand Down
33 changes: 1 addition & 32 deletions tests/unit/sagemaker/jumpstart/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=version
)
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
mock_cache.get_specs.assert_called_once_with(model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS)
mock_cache.get_hub_model.assert_not_called()

accessors.JumpStartModelsAccessor.get_model_specs(
Expand Down Expand Up @@ -139,37 +139,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
)


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
def test_jumpstart_models_cache_get_model_specs(mock_cache):
mock_cache.get_specs = Mock()
mock_cache.get_hub_model = Mock()
model_id, version = "pytorch-ic-mobilenet-v2", "*"
region = "us-west-2"

accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=version
)
mock_cache.get_specs.assert_called_once_with(
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
)
mock_cache.get_hub_model.assert_not_called()

accessors.JumpStartModelsAccessor.get_model_specs(
region=region,
model_id=model_id,
version=version,
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
)
mock_cache.get_hub_model.assert_called_once_with(
hub_model_arn=(
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
)
)

# necessary because accessors is a static module
reload(accessors)


@patch("sagemaker.jumpstart.cache.JumpStartModelsCache")
def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock):

Expand Down
2 changes: 0 additions & 2 deletions tests/unit/sagemaker/script_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand Down

0 comments on commit b50c557

Please sign in to comment.