diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 9df744531e..e724bbd1e7 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -178,7 +178,10 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin ) tasks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( - filter=filter, region=region, sagemaker_session=sagemaker_session + filter=filter, + region=region, + sagemaker_session=sagemaker_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ): _, task, _ = extract_framework_task_model(model_id) tasks.add(task) @@ -209,7 +212,10 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin ) frameworks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( - filter=filter, region=region, sagemaker_session=sagemaker_session + filter=filter, + region=region, + sagemaker_session=sagemaker_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ): framework, _, _ = extract_framework_task_model(model_id) frameworks.add(framework) @@ -244,7 +250,10 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin scripts: Set[str] = set() for model_id, version in _generate_jumpstart_model_versions( - filter=filter, region=region, sagemaker_session=sagemaker_session + filter=filter, + region=region, + sagemaker_session=sagemaker_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ): scripts.add(JumpStartScriptScope.INFERENCE) model_specs = verify_model_region_and_return_specs( @@ -337,6 +346,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin region: Optional[str] = None, list_incomplete_models: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: Optional[JumpStartModelType] = None, ) -> Generator: """Generate models for JumpStart, and optionally apply filters to result. @@ -370,12 +380,22 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin s3_client=sagemaker_session.s3_client, model_type=JumpStartModelType.OPEN_WEIGHTS, ) - models_manifest_list = open_weight_manifest_list + prop_models_manifest_list + models_manifest_list = ( + open_weight_manifest_list + if model_type == JumpStartModelType.OPEN_WEIGHTS + else ( + prop_models_manifest_list + if model_type == JumpStartModelType.PROPRIETARY + else open_weight_manifest_list + prop_models_manifest_list + ) + ) if isinstance(filter, str): filter = Identity(filter) - manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__) + manifest_keys = set( + open_weight_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__ + ) all_keys: Set[str] = set() diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 862d2b4174..ba11670c37 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -17,6 +17,7 @@ get_prototype_manifest, get_prototype_model_spec, ) +from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, @@ -40,8 +41,8 @@ def test_list_jumpstart_scripts( patched_read_s3_file: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( @@ -63,7 +64,9 @@ def test_list_jumpstart_scripts( } assert list_jumpstart_scripts(**kwargs) == sorted(["inference", "training"]) patched_generate_jumpstart_models.assert_called_once_with( - **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + **kwargs, + model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) assert patched_get_manifest.call_count == 2 assert patched_get_model_specs.call_count == 1 @@ -76,12 +79,15 @@ def test_list_jumpstart_scripts( "filter": "training_supported is False", "region": "sa-east-1", } + num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) assert list_jumpstart_scripts(**kwargs) == [] patched_generate_jumpstart_models.assert_called_once_with( - **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + **kwargs, + model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) assert patched_get_manifest.call_count == 2 - assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT) + assert patched_read_s3_file.call_count == num_specs @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -93,8 +99,8 @@ def test_list_jumpstart_tasks( patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions @@ -122,7 +128,9 @@ def test_list_jumpstart_tasks( } assert list_jumpstart_tasks(**kwargs) == ["ic"] patched_generate_jumpstart_models.assert_called_once_with( - **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + **kwargs, + model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() @@ -137,8 +145,8 @@ def test_list_jumpstart_frameworks( patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions @@ -180,7 +188,9 @@ def test_list_jumpstart_frameworks( ) patched_generate_jumpstart_models.assert_called_once_with( - **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION + **kwargs, + model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) assert patched_get_manifest.call_count == 4 patched_get_model_specs.assert_not_called() @@ -229,8 +239,8 @@ def test_list_jumpstart_models_script_filter( patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) manifest_length = len(get_prototype_manifest()) @@ -516,8 +526,8 @@ def test_list_jumpstart_models_vulnerable_models( patched_get_manifest: Mock, ): - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) def vulnerable_inference_model_spec(bucket, key, *args, **kwargs) -> str: @@ -533,11 +543,12 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): patched_read_s3_file.side_effect = vulnerable_inference_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) + num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) assert [] == list_jumpstart_models( And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_read_s3_file.call_count == num_specs + num_prop_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -549,7 +560,7 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_read_s3_file.call_count == num_specs + num_prop_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -567,8 +578,8 @@ def test_list_jumpstart_models_deprecated_models( patched_get_manifest: Mock, ): - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: @@ -579,9 +590,10 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: patched_read_s3_file.side_effect = deprecated_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) + num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) assert [] == list_jumpstart_models("deprecated equals false") - assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_read_s3_file.call_count == num_specs + num_prop_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -666,8 +678,8 @@ def test_list_jumpstart_models_complex_queries( patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) assert list_jumpstart_models( @@ -711,8 +723,8 @@ def test_list_jumpstart_models_multiple_level_index( patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) with pytest.raises(NotImplementedError): @@ -730,8 +742,8 @@ def test_get_model_url( patched_get_model_specs.side_effect = get_prototype_model_spec patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( - region + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) model_id, version = "xgboost-classification-model", "1.0.0"