def _set_optimization_image_default(
    self, create_optimization_job_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Defaults the optimization image to the JumpStart deployment config default

    Only ``ModelQuantizationConfig`` and ``ModelShardingConfig`` entries are
    touched, and only when they do not already carry an ``Image``. The request
    is mutated in place and also returned.

    Args:
        create_optimization_job_args (Dict[str, Any]): create optimization job request

    Returns:
        Dict[str, Any]: create optimization job request with image uri default
    """
    # Read the JumpStart image defensively: without one there is nothing to
    # default to, so the request is handed back untouched instead of raising.
    init_kwargs = getattr(self.pysdk_model, "init_kwargs", None) or {}
    jumpstart_image = init_kwargs.get("image_uri")
    if not jumpstart_image:
        return create_optimization_job_args

    default_image = self._get_default_vllm_image(jumpstart_image)

    # "OptimizationConfigs" may be absent or None; treat both as "no configs".
    optimization_configs = create_optimization_job_args.get("OptimizationConfigs") or []
    for optimization_config in optimization_configs:
        for config_key in ("ModelQuantizationConfig", "ModelShardingConfig"):
            technique_config = optimization_config.get(config_key)
            # An explicit Image supplied by the caller always wins.
            if technique_config and not technique_config.get("Image"):
                technique_config["Image"] = default_image

    return create_optimization_job_args


def _get_default_vllm_image(self, image: str) -> str:
    """Ensures the minimum working image version for vLLM enabled optimization techniques

    The LMI major version is read from the ``lmi<major>.<minor>.<patch>``
    component of the image tag. When that version is below 13, the same
    repository is returned with the ``0.31.0-lmi13.0.0-cu124`` tag. Images
    whose tag carries no parseable LMI version are returned unchanged.

    Args:
        image (str): JumpStart provided default image

    Returns:
        str: minimum working image version
    """
    minimum_lmi_major_version = 13
    minimum_lmi_tag = "0.31.0-lmi13.0.0-cu124"

    if not isinstance(image, str):
        return image

    # rpartition splits on the LAST colon so a "host:port/repo:tag" registry
    # prefix stays part of the repository name.
    dlc_name, separator, dlc_tag = image.rpartition(":")
    if not separator or "/" in dlc_tag:
        # No tag at all (the last colon, if any, belonged to the registry port).
        return image

    lmi_component = next(
        (part for part in dlc_tag.split("-") if part.startswith("lmi")), None
    )
    if lmi_component is None:
        # Not an LMI image tag; there is no LMI version to compare.
        return image

    try:
        major_version_number = int(lmi_component[len("lmi"):].partition(".")[0])
    except ValueError:
        return image

    if major_version_number < minimum_lmi_major_version:
        minimum_version_default = f"{dlc_name}:{minimum_lmi_tag}"
        logger.info(f"Defaulting to {minimum_version_default} image for optimization")
        return minimum_version_default
    return image
schema_builder = SchemaBuilder("test", "test") model_builder = ModelBuilder( model="meta-textgeneration-llama-3-1-8b-instruct", @@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e accept_eula=True, ) + assert not sagemaker_session.sagemaker_client.create_optimization_job.called + optimized_model.deploy() mock_create_model.assert_called_once_with( @@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_ accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelShardingConfig"]["Image"] + is not None + ) + optimized_model.deploy( resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8}) ) @@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + is not None + ) + optimized_model.deploy() mock_create_model.assert_called_once_with( diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index b6bd69e3047..809e828f920 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -75,7 +75,7 @@ "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" ) mock_djl_image_uri = ( - "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1" + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" ) mock_model_data = { @@ -1166,6 +1166,7 @@ def test_optimize_quantize_for_jumpstart( mock_pysdk_model.image_uri = mock_tgi_image_uri mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + 
mock_pysdk_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"} sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1201,6 +1202,9 @@ def test_optimize_quantize_for_jumpstart( ) self.assertIsNotNone(out_put) + self.assertEqual( + out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" + ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1287,6 +1291,7 @@ def test_optimize_quantize_and_compile_for_jumpstart( mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] mock_pysdk_model.config_name = "config_name" mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config} + mock_pysdk_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"} sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1319,6 +1324,8 @@ def test_optimize_quantize_and_compile_for_jumpstart( ) self.assertIsNotNone(out_put) + self.assertIsNone(out_put["OptimizationConfigs"][1]["ModelCompilationConfig"].get("Image")) + self.assertIsNone(out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"].get("Image")) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1640,6 +1647,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"} mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1718,6 +1726,7 
@@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"} mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 1e20bf1cf3f..c7d31335cf9 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -3733,6 +3733,7 @@ def test_optimize_sharding_with_override_for_js( pysdk_model.env = {"key": "val"} pysdk_model._enable_network_isolation = True pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + pysdk_model.init_kwargs = {"image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"} mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( @@ -3803,8 +3804,9 @@ def test_optimize_sharding_with_override_for_js( OptimizationConfigs=[ { "ModelShardingConfig": { - "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} - } + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + }, } ], OutputConfig={