Skip to content

Commit

Permalink
fix: JumpStart list models flaky tests (#4525)
Browse files: browse the repository at this point in the history
* fix list models flaky tests

* fix
  • Loading branch information
Captainia authored Mar 22, 2024
1 parent 5b87888 commit ca9ae23
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 31 deletions.
30 changes: 25 additions & 5 deletions src/sagemaker/jumpstart/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
)
tasks: Set[str] = set()
for model_id, _ in _generate_jumpstart_model_versions(
filter=filter, region=region, sagemaker_session=sagemaker_session
filter=filter,
region=region,
sagemaker_session=sagemaker_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
):
_, task, _ = extract_framework_task_model(model_id)
tasks.add(task)
Expand Down Expand Up @@ -209,7 +212,10 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
)
frameworks: Set[str] = set()
for model_id, _ in _generate_jumpstart_model_versions(
filter=filter, region=region, sagemaker_session=sagemaker_session
filter=filter,
region=region,
sagemaker_session=sagemaker_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
):
framework, _, _ = extract_framework_task_model(model_id)
frameworks.add(framework)
Expand Down Expand Up @@ -244,7 +250,10 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin

scripts: Set[str] = set()
for model_id, version in _generate_jumpstart_model_versions(
filter=filter, region=region, sagemaker_session=sagemaker_session
filter=filter,
region=region,
sagemaker_session=sagemaker_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
):
scripts.add(JumpStartScriptScope.INFERENCE)
model_specs = verify_model_region_and_return_specs(
Expand Down Expand Up @@ -337,6 +346,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
region: Optional[str] = None,
list_incomplete_models: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: Optional[JumpStartModelType] = None,
) -> Generator:
"""Generate models for JumpStart, and optionally apply filters to result.
Expand Down Expand Up @@ -370,12 +380,22 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
s3_client=sagemaker_session.s3_client,
model_type=JumpStartModelType.OPEN_WEIGHTS,
)
models_manifest_list = open_weight_manifest_list + prop_models_manifest_list
models_manifest_list = (
open_weight_manifest_list
if model_type == JumpStartModelType.OPEN_WEIGHTS
else (
prop_models_manifest_list
if model_type == JumpStartModelType.PROPRIETARY
else open_weight_manifest_list + prop_models_manifest_list
)
)

if isinstance(filter, str):
filter = Identity(filter)

manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__)
manifest_keys = set(
open_weight_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__
)

all_keys: Set[str] = set()

Expand Down
64 changes: 38 additions & 26 deletions tests/unit/sagemaker/jumpstart/test_notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_prototype_manifest,
get_prototype_model_spec,
)
from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.notebook_utils import (
_generate_jumpstart_model_versions,
Expand All @@ -40,8 +41,8 @@ def test_list_jumpstart_scripts(
patched_read_s3_file: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
Expand All @@ -63,7 +64,9 @@ def test_list_jumpstart_scripts(
}
assert list_jumpstart_scripts(**kwargs) == sorted(["inference", "training"])
patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 2
assert patched_get_model_specs.call_count == 1
Expand All @@ -76,12 +79,15 @@ def test_list_jumpstart_scripts(
"filter": "training_supported is False",
"region": "sa-east-1",
}
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
assert list_jumpstart_scripts(**kwargs) == []
patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 2
assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT)
assert patched_read_s3_file.call_count == num_specs


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
Expand All @@ -93,8 +99,8 @@ def test_list_jumpstart_tasks(
patched_get_manifest: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions

Expand Down Expand Up @@ -122,7 +128,9 @@ def test_list_jumpstart_tasks(
}
assert list_jumpstart_tasks(**kwargs) == ["ic"]
patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 2
patched_get_model_specs.assert_not_called()
Expand All @@ -137,8 +145,8 @@ def test_list_jumpstart_frameworks(
patched_get_manifest: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions

Expand Down Expand Up @@ -180,7 +188,9 @@ def test_list_jumpstart_frameworks(
)

patched_generate_jumpstart_models.assert_called_once_with(
**kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION
**kwargs,
model_type=JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
assert patched_get_manifest.call_count == 4
patched_get_model_specs.assert_not_called()
Expand Down Expand Up @@ -229,8 +239,8 @@ def test_list_jumpstart_models_script_filter(
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
)
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

manifest_length = len(get_prototype_manifest())
Expand Down Expand Up @@ -516,8 +526,8 @@ def test_list_jumpstart_models_vulnerable_models(
patched_get_manifest: Mock,
):

patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

def vulnerable_inference_model_spec(bucket, key, *args, **kwargs) -> str:
Expand All @@ -533,11 +543,12 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
patched_read_s3_file.side_effect = vulnerable_inference_model_spec

num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
assert [] == list_jumpstart_models(
And("inference_vulnerable is false", "training_vulnerable is false")
)

assert patched_read_s3_file.call_count == 2 * num_specs
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand All @@ -549,7 +560,7 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
And("inference_vulnerable is false", "training_vulnerable is false")
)

assert patched_read_s3_file.call_count == 2 * num_specs
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand All @@ -567,8 +578,8 @@ def test_list_jumpstart_models_deprecated_models(
patched_get_manifest: Mock,
):

patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
Expand All @@ -579,9 +590,10 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
patched_read_s3_file.side_effect = deprecated_model_spec

num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
assert [] == list_jumpstart_models("deprecated equals false")

assert patched_read_s3_file.call_count == 2 * num_specs
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand Down Expand Up @@ -666,8 +678,8 @@ def test_list_jumpstart_models_complex_queries(
patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
)
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

assert list_jumpstart_models(
Expand Down Expand Up @@ -711,8 +723,8 @@ def test_list_jumpstart_models_multiple_level_index(
patched_get_manifest: Mock,
):
patched_get_model_specs.side_effect = get_prototype_model_spec
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

with pytest.raises(NotImplementedError):
Expand All @@ -730,8 +742,8 @@ def test_get_model_url(

patched_get_model_specs.side_effect = get_prototype_model_spec
patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
region
patched_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

model_id, version = "xgboost-classification-model", "1.0.0"
Expand Down

0 comments on commit ca9ae23

Please sign in to comment.