From c622a73c6327c28123b6288d6394e4a9b4515b26 Mon Sep 17 00:00:00 2001 From: xiongz945 <54782408+xiongz945@users.noreply.github.com> Date: Thu, 7 Mar 2024 18:31:44 -0800 Subject: [PATCH 01/20] feature: Add overriding logic in ModelBuilder when task is provided (#4460) * feat: Add Optional task to Model * Revert "feat: Add Optional task to Model" This reverts commit fd3e86b19f091dc1ce4d7906644107e08768c6a4. * Add override logic in ModelBuilder with task provided * Adjusted formatting * Add extra unit tests for invalid inputs * Address PR comments * Add more test inputs to integration test * Add model_metadata field to ModelBuilder * Update doc * Update doc * Adjust formatting --------- Co-authored-by: Samrudhi Sharma Co-authored-by: Xiong Zeng --- src/sagemaker/huggingface/llm_utils.py | 3 +- src/sagemaker/serve/builder/model_builder.py | 19 ++- src/sagemaker/serve/schema/task.json | 40 +++--- .../sagemaker/serve/test_schema_builder.py | 66 +++++++++ .../serve/builder/test_model_builder.py | 125 +++++++++++++++++- 5 files changed, 226 insertions(+), 27 deletions(-) diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index 1a2abfb2e4..de5e624dbc 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = Returns: dict: The model metadata retrieved with the HuggingFace API """ - + if not model_id: + raise ValueError("Model ID is empty. Please provide a valid Model ID.") hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}" hf_model_metadata_json = None try: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index c66057397f..4d1e51cb26 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -124,8 +124,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): into a stream. All translations between the server and the client are handled automatically with the specified input and output. model (Optional[Union[object, str]): Model object (with ``predict`` method to perform - inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or - ``inference_spec`` is required for the model builder to build the artifact. + inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec`` + is required for the model builder to build the artifact. inference_spec (InferenceSpec): The inference spec file with your customized ``invoke`` and ``load`` functions. image_uri (Optional[str]): The container image uri (which is derived from a @@ -145,6 +145,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): to the model server). Possible values for this argument are ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``, ``TRITON``, and``TGI``. + model_metadata (Optional[Dict[str, Any]): Dictionary used to override the HuggingFace + model metadata. Currently ``HF_TASK`` is overridable. """ model_path: Optional[str] = field( @@ -241,6 +243,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): model_server: Optional[ModelServer] = field( default=None, metadata={"help": "Define the model server to deploy to."} ) + model_metadata: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"}, + ) def _build_validations(self): """Placeholder docstring""" @@ -616,6 +622,9 @@ def build( # pylint: disable=R0911 self._is_custom_image_uri = self.image_uri is not None if isinstance(self.model, str): + model_task = None + if self.model_metadata: + model_task = self.model_metadata.get("HF_TASK") if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 @@ -625,10 +634,10 @@ def build( # pylint: disable=R0911 self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - model_task = hf_model_md.get("pipeline_tag") - if self.schema_builder is None and model_task: + if model_task is None: + model_task = hf_model_md.get("pipeline_tag") + if self.schema_builder is None and model_task is not None: self._schema_builder_init(model_task) - if model_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() elif self._can_fit_on_single_gpu(): diff --git a/src/sagemaker/serve/schema/task.json b/src/sagemaker/serve/schema/task.json index 1a7bdce5d0..c897f4abec 100644 --- a/src/sagemaker/serve/schema/task.json +++ b/src/sagemaker/serve/schema/task.json @@ -1,12 +1,12 @@ { "fill-mask": { - "sample_inputs": { + "sample_inputs": { "properties": { "inputs": "Paris is the [MASK] of France.", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "sequence": "Paris is the capital of France.", @@ -14,15 +14,15 @@ } ] } - }, + }, "question-answering": { - "sample_inputs": { + "sample_inputs": { "properties": { "context": "I have a German Shepherd dog, named Coco.", "question": "What is my dog's breed?" } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "answer": "German Shepherd", @@ -32,15 +32,15 @@ } ] } - }, + }, "text-classification": { - "sample_inputs": { + "sample_inputs": { "properties": { "inputs": "Where is the capital of France?, Paris is the capital of France.", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "label": "entailment", @@ -48,20 +48,20 @@ } ] } - }, - "text-generation": { - "sample_inputs": { + }, + "text-generation": { + "sample_inputs": { "properties": { "inputs": "Hello, I'm a language model", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ - { - "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" - } + { + "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" + } ] } - } + } } diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index 3816985d8f..2b6ac48460 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -99,3 +99,69 @@ def test_model_builder_negative_path(sagemaker_session): match="Error Message: Schema builder for text-to-image could not be found.", ): model_builder.build(sagemaker_session=sagemaker_session) + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing Schema Builder Simplification feature", +) +@pytest.mark.parametrize( + "model_id, task_provided", + [ + ("bert-base-uncased", "fill-mask"), + ("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"), + ], +) +def test_model_builder_happy_path_with_task_provided( + model_id, task_provided, sagemaker_session, gpu_instance_type +): + model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": task_provided}) + + model = model_builder.build(sagemaker_session=sagemaker_session) + + assert model is not None + assert model_builder.schema_builder is not None + + inputs, outputs = task.retrieve_local_schemas(task_provided) + assert model_builder.schema_builder.sample_input == inputs + assert model_builder.schema_builder.sample_output == outputs + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + caught_ex = None + try: + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy( + role=role_arn, instance_count=1, instance_type=gpu_instance_type + ) + + predicted_outputs = predictor.predict(inputs) + assert predicted_outputs is not None + + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test" + + +def test_model_builder_negative_path_with_invalid_task(sagemaker_session): + model_builder = ModelBuilder( + model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"} + ) + + with pytest.raises( + TaskNotFoundException, + match="Error Message: Schema builder for invalid-task could not be found.", + ): + model_builder.build(sagemaker_session=sagemaker_session) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 1f743ff442..3b60d13dfb 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1076,7 +1076,7 @@ def test_build_negative_path_when_schema_builder_not_present( model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4") - self.assertRaisesRegexp( + self.assertRaisesRegex( TaskNotFoundException, "Error Message: Schema builder for text-to-image could not be found.", lambda: model_builder.build(sagemaker_session=mock_session), @@ -1593,3 +1593,126 @@ def test_total_inference_model_size_mib_throws( model_builder.build(sagemaker_session=mock_session) self.assertEqual(model_builder._can_fit_on_single_gpu(), False) + + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_happy_path_override_with_task_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_hf_model, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder( + model="bert-base-uncased", model_metadata={"HF_TASK": "text-generation"} + ) + model_builder.build(sagemaker_session=mock_session) + + self.assertIsNotNone(model_builder.schema_builder) + sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation") + self.assertEqual( + sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"] + ) + self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output) + + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_task_override_with_invalid_task_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + model_ids_with_invalid_task = { + "bert-base-uncased": "invalid-task", + "bert-large-uncased-whole-word-masking-finetuned-squad": "", + } + for model_id in model_ids_with_invalid_task: + provided_task = model_ids_with_invalid_task[model_id] + model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": provided_task}) + + self.assertRaisesRegex( + TaskNotFoundException, + f"Error Message: Schema builder for {provided_task} could not be found.", + lambda: model_builder.build(sagemaker_session=mock_session), + ) + + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_task_override_with_invalid_model_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_image_uris_retrieve, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + invalid_model_id = "" + provided_task = "fill-mask" + + model_builder = ModelBuilder( + model=invalid_model_id, model_metadata={"HF_TASK": provided_task} + ) + with self.assertRaises(Exception): + model_builder.build(sagemaker_session=mock_session) From 615a8adcbdeba0877c00123445c014f380eb2015 Mon Sep 17 00:00:00 2001 From: martinRenou Date: Fri, 8 Mar 2024 19:19:46 +0100 Subject: [PATCH 02/20] feature: Accept user-defined env variables for the entry-point (#4175) --- src/sagemaker/model.py | 8 ++--- .../test_huggingface_pytorch_compiler.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ff340b58e9..5a2b27c54d 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: def _script_mode_env_vars(self): """Returns a mapping of environment variables for script mode execution""" - script_name = None - dir_name = None + script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "") + dir_name = self.env.get(DIR_PARAM_NAME.upper(), "") if self.uploaded_code: script_name = self.uploaded_code.script_name if self.repacked_model_data or self.enable_network_isolation(): @@ -783,8 +783,8 @@ def _script_mode_env_vars(self): else "file://" + self.source_dir ) return { - SCRIPT_PARAM_NAME.upper(): script_name or str(), - DIR_PARAM_NAME.upper(): dir_name or str(), + SCRIPT_PARAM_NAME.upper(): script_name, + DIR_PARAM_NAME.upper(): dir_name, CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 2b59113354..12162f799f 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -718,3 +718,32 @@ def test_register_hf_pytorch_model_auto_infer_framework( sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request ) + + +def test_accept_user_defined_environment_variables( + sagemaker_session, + huggingface_training_compiler_version, + huggingface_training_compiler_pytorch_version, + huggingface_training_compiler_pytorch_py_version, +): + program = "inference.py" + directory = "/opt/ml/model/code" + + hf_model = HuggingFaceModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + transformers_version=huggingface_training_compiler_version, + pytorch_version=huggingface_training_compiler_pytorch_version, + py_version=huggingface_training_compiler_pytorch_py_version, + sagemaker_session=sagemaker_session, + env={ + "SAGEMAKER_PROGRAM": program, + "SAGEMAKER_SUBMIT_DIRECTORY": directory, + }, + image_uri="fakeimage", + ) + + container_env = hf_model.prepare_container_def("ml.m4.xlarge")["Environment"] + + assert container_env["SAGEMAKER_PROGRAM"] == program + assert container_env["SAGEMAKER_SUBMIT_DIRECTORY"] == directory From d3a18259da3103201aef1be2786e316e90450d6b Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:42:55 -0700 Subject: [PATCH 03/20] fix: Move sagemaker pysdk version check after bootstrap in remote job (#4487) --- .../runtime_environment/bootstrap_runtime_environment.py | 7 ++++--- .../runtime_environment/runtime_environment_manager.py | 9 ++++----- .../test_bootstrap_runtime_environment.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index d5d879cb08..8fd83bfcfe 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,9 +65,6 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) user = getpass.getuser() if user != "root": @@ -89,6 +86,10 @@ def main(sys_args=None): client_python_version, conda_env, dependency_settings ) + RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( + client_sagemaker_pysdk_version + ) + exit_code = SUCCESS_EXIT_CODE except Exception as e: # pylint: disable=broad-except logger.exception("Error encountered while bootstrapping runtime environment: %s", e) diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py index 0dd5f0d219..13493c1d15 100644 --- a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py +++ b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -24,8 +24,6 @@ import dataclasses import json -import sagemaker - class _UTCFormatter(logging.Formatter): """Class that overrides the default local time provider in log formatter.""" @@ -330,6 +328,7 @@ def _current_python_version(self): def _current_sagemaker_pysdk_version(self): """Returns the current sagemaker python sdk version where program is running""" + import sagemaker return sagemaker.__version__ @@ -366,10 +365,10 @@ def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version): ): logger.warning( "Inconsistent sagemaker versions found: " - "sagemaker pysdk version found in the container is " + "sagemaker python sdk version found in the container is " "'%s' which does not match the '%s' on the local client. " - "Please make sure that the python version used in the training container " - "is the same as the local python version in case of unexpected behaviors.", + "Please make sure that the sagemaker version used in the training container " + "is the same as the local sagemaker version in case of unexpected behaviors.", job_sagemaker_pysdk_version, client_sagemaker_pysdk_version, ) diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py index b7d9e10047..ef35c965e9 100644 --- a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py @@ -269,7 +269,7 @@ def test_main_failure_remote_job_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) - validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) + validate_sagemaker.assert_not_called() run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err)) @@ -317,7 +317,7 @@ def test_main_failure_pipeline_step_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) - validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) + validate_sagemaker.assert_not_called() run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err)) From 07e1b92327cfcd6c46b7220eb3972ef240404d7a Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Date: Tue, 12 Mar 2024 08:17:36 -0700 Subject: [PATCH 04/20] change: enable github actions for PRs (#4489) * change: enable github actions for PRs * Update codebuild-ci.yml * trigger on pull_request_target * add source-version-override * fix permission --- .github/workflows/codebuild-ci.yml | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 .github/workflows/codebuild-ci.yml diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml new file mode 100644 index 0000000000..e72680be2a --- /dev/null +++ b/.github/workflows/codebuild-ci.yml @@ -0,0 +1,48 @@ +name: PR Checks +on: + pull_request_target: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }} + cancel-in-progress: true + +permissions: + id-token: write # This is required for requesting the JWT + +jobs: + codestyle-doc-tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Codestyle & Doc Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-codestyle-doc-tests + source-version-override: 'pr/${{ github.event.pull_request.number }}' + unit-tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["py38", "py39", "py310"] + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Unit Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-unit-tests + source-version-override: 'pr/${{ github.event.pull_request.number }}' + env-vars-for-codebuild: | + PY_VERSION + env: + PY_VERSION: ${{ matrix.python-version }} From b51a613993617d82831a328e6b741c1edced7945 Mon Sep 17 00:00:00 2001 From: mrudulmn <161017394+mrudulmn@users.noreply.github.com> Date: Wed, 13 Mar 2024 01:27:05 +0530 Subject: [PATCH 05/20] feature: Add ModelDataSource and SourceUri support for model package and while registering (#4492) Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> --- src/sagemaker/chainer/model.py | 4 + src/sagemaker/estimator.py | 3 + src/sagemaker/huggingface/model.py | 4 + src/sagemaker/jumpstart/factory/model.py | 2 + src/sagemaker/jumpstart/model.py | 4 + src/sagemaker/jumpstart/types.py | 3 + src/sagemaker/model.py | 98 +++++++- src/sagemaker/mxnet/model.py | 4 + src/sagemaker/pipeline.py | 4 + src/sagemaker/pytorch/model.py | 4 + src/sagemaker/session.py | 123 +++++++++- src/sagemaker/sklearn/model.py | 4 + src/sagemaker/tensorflow/model.py | 4 + src/sagemaker/utils.py | 18 ++ src/sagemaker/workflow/_utils.py | 4 + src/sagemaker/workflow/step_collections.py | 3 + src/sagemaker/xgboost/model.py | 4 + tests/integ/test_model_package.py | 206 ++++++++++++++++ tests/unit/sagemaker/model/test_model.py | 88 ++++++- .../sagemaker/model/test_model_package.py | 75 ++++-- tests/unit/test_session.py | 227 ++++++++++++++++++ tests/unit/test_utils.py | 13 + 22 files changed, 864 insertions(+), 35 deletions(-) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index bafcfde3a8..9fce051454 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -174,6 +174,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -223,6 +224,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -262,6 +265,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 7ef367e485..501c826f82 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1718,6 +1718,7 @@ def register( nearest_model_name=None, data_input_configuration=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1765,6 +1766,7 @@ def register( data_input_configuration (str): Input object for the model (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1809,6 +1811,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index efe6a85288..f71dca0ac8 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -360,6 +360,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -410,6 +411,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -457,6 +460,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 1b41cad714..40759d0f0b 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -598,6 +598,7 @@ def get_register_kwargs( nearest_model_name: Optional[str] = None, data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, + source_uri: Optional[str] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -629,6 +630,7 @@ def get_register_kwargs( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 1742f860e4..8d007aed24 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -631,6 +631,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -676,6 +677,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -709,6 +712,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 810d1c4cd3..43a25f3c12 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1659,6 +1659,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "nearest_model_name", "data_input_configuration", "skip_model_validation", + "source_uri", ] SERIALIZATION_EXCLUSION_SET = { @@ -1699,6 +1700,7 @@ def __init__( nearest_model_name: Optional[str] = None, data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, + source_uri: Optional[str] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object.""" @@ -1730,3 +1732,4 @@ def __init__( self.nearest_model_name = nearest_model_name self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation + self.source_uri = source_uri diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5a2b27c54d..af08d1203f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -77,7 +77,10 @@ ) from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType -from sagemaker.session import get_add_model_package_inference_args +from sagemaker.session import ( + get_add_model_package_inference_args, + get_update_model_package_inference_args, +) # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") @@ -423,6 +426,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -472,17 +476,14 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments in case the Model instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ - if isinstance(self.model_data, dict): - raise ValueError( - "SageMaker Model Package currently cannot be created with ModelDataSource." - ) - if content_types is not None: self.content_types = content_types @@ -513,6 +514,12 @@ def register( "Image": self.image_uri, } + if isinstance(self.model_data, dict): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " + "created with ModelDataSource." + ) + if self.model_data is not None: container_def["ModelDataUrl"] = self.model_data @@ -536,6 +543,7 @@ def register( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args @@ -2040,8 +2048,9 @@ def __init__( endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. - model_data (str): The S3 location of a SageMaker model data - ``.tar.gz`` file. Must be provided if algorithm_arn is provided. + model_data (str or dict[str, Any]): The S3 location of a SageMaker model data + ``.tar.gz`` file or a dictionary representing a ``ModelDataSource`` + object. Must be provided if algorithm_arn is provided. algorithm_arn (str): algorithm arn used to train the model, can be just the name if your account owns the algorithm. Must also provide ``model_data``. @@ -2050,11 +2059,6 @@ def __init__( ``model_data`` is not required. **kwargs: Additional kwargs passed to the Model constructor. """ - if isinstance(model_data, dict): - raise ValueError( - "Creating ModelPackage with ModelDataSource is currently not supported" - ) - super(ModelPackage, self).__init__( role=role, model_data=model_data, image_uri=None, **kwargs ) @@ -2222,6 +2226,74 @@ def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]) sagemaker_session = self.sagemaker_session or sagemaker.Session() sagemaker_session.sagemaker_client.update_model_package(**update_metadata_args) + def update_inference_specification( + self, + containers: Dict = None, + image_uris: List[str] = None, + content_types: List[str] = None, + response_types: List[str] = None, + inference_instances: List[str] = None, + transform_instances: List[str] = None, + ): + """Inference specification to be set for the model package + + Args: + containers (dict): The Amazon ECR registry path of the Docker image + that contains the inference code. + image_uris (List[str]): The ECR path where inference code is stored. + content_types (list[str]): The supported MIME types + for the input data. + response_types (list[str]): The supported MIME types + for the output data. + inference_instances (list[str]): A list of the instance + types that are used to generate inferences in real-time (default: None). + transform_instances (list[str]): A list of the instance + types on which a transformation job can be run or on which an endpoint can be + deployed (default: None). + + """ + sagemaker_session = self.sagemaker_session or sagemaker.Session() + if (containers is not None) ^ (image_uris is None): + raise ValueError("Should have either containers or image_uris for inference.") + container_def = [] + if image_uris: + for uri in image_uris: + container_def.append( + { + "Image": uri, + } + ) + else: + container_def = containers + + model_package_update_args = get_update_model_package_inference_args( + model_package_arn=self.model_package_arn, + containers=container_def, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + ) + + sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args) + + def update_source_uri( + self, + source_uri: str, + ): + """Source uri to be set for the model package + + Args: + source_uri (str): The URI of the source for the model package. + + """ + update_source_uri_args = { + "ModelPackageArn": self.model_package_arn, + "SourceUri": source_uri, + } + sagemaker_session = self.sagemaker_session or sagemaker.Session() + sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args) + def remove_customer_metadata_properties( self, customer_metadata_properties_to_remove: List[str] ): diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 8cd0ac6b65..714b0db945 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -176,6 +176,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -225,6 +226,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -264,6 +267,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index a4b7feac69..3bfdb1a594 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -360,6 +360,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -409,6 +410,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step @@ -456,6 +459,7 @@ def register( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index fb731cabf4..f490e49375 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -178,6 +178,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -227,6 +228,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -266,6 +269,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 325d3c7697..8d72051cc0 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -140,6 +140,7 @@ ) from sagemaker import exceptions from sagemaker.session_settings import SessionSettings +from sagemaker.utils import can_model_package_source_uri_autopopulate # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") @@ -3969,14 +3970,19 @@ def create_model_package_from_algorithm(self, name, description, algorithm_arn, name (str): ModelPackage name description (str): Model Package description algorithm_arn (str): arn or name of the algorithm used for training. - model_data (str): s3 URI to the model artifacts produced by training + model_data (str or dict[str, Any]): s3 URI or a dictionary representing a + ``ModelDataSource`` to the model artifacts produced by training """ + sourceAlgorithm = {"AlgorithmName": algorithm_arn} + if isinstance(model_data, dict): + sourceAlgorithm["ModelDataSource"] = model_data + else: + sourceAlgorithm["ModelDataUrl"] = model_data + request = { "ModelPackageName": name, "ModelPackageDescription": description, - "SourceAlgorithmSpecification": { - "SourceAlgorithms": [{"AlgorithmName": algorithm_arn, "ModelDataUrl": model_data}] - }, + "SourceAlgorithmSpecification": {"SourceAlgorithms": [sourceAlgorithm]}, } try: logger.info("Creating model package with name: %s", name) @@ -4011,6 +4017,7 @@ def create_model_package_from_containers( sample_payload_url=None, task=None, skip_model_validation="None", + source_uri=None, ): """Get request dictionary for CreateModelPackage API. @@ -4047,6 +4054,7 @@ def create_model_package_from_containers( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -4103,6 +4111,7 @@ def create_model_package_from_containers( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def submit(request): @@ -4114,6 +4123,26 @@ def submit(request): ModelPackageGroupName=request["ModelPackageGroupName"] ) ) + if "SourceUri" in request and request["SourceUri"] is not None: + # Remove inference spec from request if the + # given source uri can lead to auto-population of it + if can_model_package_source_uri_autopopulate(request["SourceUri"]): + if "InferenceSpecification" in request: + del request["InferenceSpecification"] + return self.sagemaker_client.create_model_package(**request) + # If source uri can't autopopulate, + # first create model package with just the inference spec + # and then update model package with the source uri. + # Done this way because passing source uri and inference spec together + # in create/update model package is not allowed in the base sdk. + request_source_uri = request["SourceUri"] + del request["SourceUri"] + model_package = self.sagemaker_client.create_model_package(**request) + update_source_uri_args = { + "ModelPackageArn": model_package.get("ModelPackageArn"), + "SourceUri": request_source_uri, + } + return self.sagemaker_client.update_model_package(**update_source_uri_args) return self.sagemaker_client.create_model_package(**request) return self._intercept_create_request( @@ -6669,6 +6698,7 @@ def get_model_package_args( sample_payload_url=None, task=None, skip_model_validation=None, + source_uri=None, ): """Get arguments for create_model_package method. @@ -6707,6 +6737,7 @@ def get_model_package_args( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). Returns: dict: A dictionary of method argument names and values. @@ -6761,6 +6792,8 @@ def get_model_package_args( model_package_args["task"] = task if skip_model_validation is not None: model_package_args["skip_model_validation"] = skip_model_validation + if source_uri is not None: + model_package_args["source_uri"] = source_uri return model_package_args @@ -6785,6 +6818,7 @@ def get_create_model_package_request( sample_payload_url=None, task=None, skip_model_validation="None", + source_uri=None, ): """Get request dictionary for CreateModelPackage API. @@ -6821,12 +6855,32 @@ def get_create_model_package_request( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). """ if all([model_package_name, model_package_group_name]): raise ValueError( "model_package_name and model_package_group_name cannot be present at the " "same time." ) + if all([model_package_name, source_uri]): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " "created with source_uri." + ) + if (containers is not None) and all( + [ + model_package_name, + any( + [ + (("ModelDataSource" in c) and (c["ModelDataSource"] is not None)) + for c in containers + ] + ), + ] + ): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " + "created with ModelDataSource." + ) request_dict = {} if model_package_name is not None: request_dict["ModelPackageName"] = model_package_name @@ -6852,6 +6906,8 @@ def get_create_model_package_request( request_dict["SamplePayloadUrl"] = sample_payload_url if task is not None: request_dict["Task"] = task + if source_uri is not None: + request_dict["SourceUri"] = source_uri if containers is not None: inference_specification = { "Containers": containers, @@ -6900,6 +6956,65 @@ def get_create_model_package_request( return request_dict +def get_update_model_package_inference_args( + model_package_arn, + containers=None, + content_types=None, + response_types=None, + inference_instances=None, + transform_instances=None, +): + """Get request dictionary for UpdateModelPackage API for inference specification. + + Args: + model_package_arn (str): Arn for the model package. + containers (dict): The Amazon ECR registry path of the Docker image + that contains the inference code. + content_types (list[str]): The supported MIME types + for the input data. + response_types (list[str]): The supported MIME types + for the output data. + inference_instances (list[str]): A list of the instance + types that are used to generate inferences in real-time (default: None). + transform_instances (list[str]): A list of the instance + types on which a transformation job can be run or on which an endpoint can be + deployed (default: None). + """ + + request_dict = {} + if containers is not None: + inference_specification = { + "Containers": containers, + } + if content_types is not None: + inference_specification.update( + { + "SupportedContentTypes": content_types, + } + ) + if response_types is not None: + inference_specification.update( + { + "SupportedResponseMIMETypes": response_types, + } + ) + if inference_instances is not None: + inference_specification.update( + { + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + } + ) + if transform_instances is not None: + inference_specification.update( + { + "SupportedTransformInstanceTypes": transform_instances, + } + ) + request_dict["InferenceSpecification"] = inference_specification + request_dict.update({"ModelPackageArn": model_package_arn}) + return request_dict + + def get_add_model_package_inference_args( model_package_arn, name, diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 195a6a3a57..27833c1d9c 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -171,6 +171,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -220,6 +221,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -259,6 +262,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 1b35afbe7c..77f162207c 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -233,6 +233,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -282,6 +283,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -321,6 +324,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def deploy( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 115b8b258d..7896aac150 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -48,6 +48,10 @@ from sagemaker.workflow.entities import PipelineVariable ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" +MODEL_PACKAGE_ARN_PATTERN = ( + r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)" +) +MODEL_ARN_PATTERN = r"arn:aws([a-z\-]*):sagemaker:([a-z0-9\-]*):([0-9]{12}):model/(.*)" MAX_BUCKET_PATHS_COUNT = 5 S3_PREFIX = "s3://" HTTP_PREFIX = "http://" @@ -1581,3 +1585,17 @@ def custom_extractall_tarfile(tar, extract_path): tar.extractall(path=extract_path, filter="data") else: tar.extractall(path=extract_path, members=_get_safe_members(tar)) + + +def can_model_package_source_uri_autopopulate(source_uri: str): + """Checks if the source_uri can lead to auto-population of information in the Model registry. + + Args: + source_uri (str): The source uri. + + Returns: + bool: True if the source_uri can lead to auto-population, False otherwise. + """ + return bool( + re.match(MODEL_PACKAGE_ARN_PATTERN, source_uri) or re.match(MODEL_ARN_PATTERN, source_uri) + ) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 1fafa646bf..841cd68083 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -328,6 +328,7 @@ def __init__( sample_payload_url=None, task=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Constructor of a register model step. @@ -379,6 +380,7 @@ def __init__( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -415,6 +417,7 @@ def __init__( self.kwargs = kwargs self.container_def_list = container_def_list self.skip_model_validation = skip_model_validation + self.source_uri = source_uri self._properties = Properties( step_name=name, step=self, shape_name="DescribeModelPackageOutput" @@ -489,6 +492,7 @@ def arguments(self) -> RequestType: sample_payload_url=self.sample_payload_url, task=self.task, skip_model_validation=self.skip_model_validation, + source_uri=self.source_uri, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index d48bf7c307..0eedf4aa96 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -96,6 +96,7 @@ def __init__( nearest_model_name=None, data_input_configuration=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -153,6 +154,7 @@ def __init__( data_input_configuration (str): Input object for the model (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: additional arguments to `create_model`. """ @@ -291,6 +293,7 @@ def __init__( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 74776f8f72..8101f32721 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -159,6 +159,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -208,6 +209,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -247,6 +250,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py index 1554825fc2..914c5db7ed 100644 --- a/tests/integ/test_model_package.py +++ b/tests/integ/test_model_package.py @@ -18,6 +18,8 @@ from tests.integ import DATA_DIR from sagemaker.xgboost import XGBoostModel from sagemaker import image_uris +from sagemaker.session import get_execution_role +from sagemaker.model import ModelPackage _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") @@ -104,3 +106,207 @@ def test_inference_specification_addition(sagemaker_session): sagemaker_session.sagemaker_client.delete_model_package_group( ModelPackageGroupName=model_group_name ) + + +def test_update_inference_specification(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + source_uri = "dummy source uri" + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + model_package = sagemaker_session.sagemaker_client.create_model_package( + ModelPackageGroupName=model_group_name, SourceUri=source_uri + ) + + mp = ModelPackage( + role=get_execution_role(sagemaker_session), + model_package_arn=model_package["ModelPackageArn"], + sagemaker_session=sagemaker_session, + ) + + xgb_image = image_uris.retrieve( + "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference" + ) + + mp.update_inference_specification(image_uris=[xgb_image]) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert len(desc_model_package["InferenceSpecification"]["Containers"]) == 1 + assert desc_model_package["InferenceSpecification"]["Containers"][0]["Image"] == xgb_image + + +def test_update_source_uri(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + source_uri = "dummy source uri" + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + model_package.update_source_uri(source_uri=source_uri) + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + assert desc_model_package["SourceUri"] == source_uri + + +def test_clone_model_package_using_source_uri(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + source_uri="dummy-source-uri", + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + model2 = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + cloned_model_package = model2.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + source_uri=model_package.model_package_arn, + ) + + desc_cloned_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=cloned_model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=cloned_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert len(desc_cloned_model_package["InferenceSpecification"]["Containers"]) == len( + desc_model_package["InferenceSpecification"]["Containers"] + ) + assert len( + desc_cloned_model_package["InferenceSpecification"]["SupportedTransformInstanceTypes"] + ) == len(desc_model_package["InferenceSpecification"]["SupportedTransformInstanceTypes"]) + assert len( + desc_cloned_model_package["InferenceSpecification"][ + "SupportedRealtimeInferenceInstanceTypes" + ] + ) == len( + desc_model_package["InferenceSpecification"]["SupportedRealtimeInferenceInstanceTypes"] + ) + assert len(desc_cloned_model_package["InferenceSpecification"]["SupportedContentTypes"]) == len( + desc_model_package["InferenceSpecification"]["SupportedContentTypes"] + ) + assert len( + desc_cloned_model_package["InferenceSpecification"]["SupportedResponseMIMETypes"] + ) == len(desc_model_package["InferenceSpecification"]["SupportedResponseMIMETypes"]) + assert desc_cloned_model_package["SourceUri"] == model_package.model_package_arn + + +def test_register_model_using_source_uri(sagemaker_session): + model_name = unique_name_from_base("test-model") + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + + model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + sagemaker_session=sagemaker_session, + role=get_execution_role(sagemaker_session), + ) + + model.name = model_name + model.create() + desc_model = sagemaker_session.sagemaker_client.describe_model(ModelName=model_name) + + model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + sagemaker_session=sagemaker_session, + role=get_execution_role(sagemaker_session), + ) + registered_model_package = model.register( + inference_instances=["ml.m5.xlarge"], + model_package_group_name=model_group_name, + source_uri=desc_model["ModelArn"], + ) + + desc_registered_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=registered_model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model(ModelName=model_name) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=registered_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert desc_registered_model_package["SourceUri"] == desc_model["ModelArn"] + assert "InferenceSpecification" in desc_registered_model_package + assert desc_registered_model_package["InferenceSpecification"] is not None diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index de86fcf99a..c0b18a3eb3 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -1118,7 +1118,46 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses get_model_package_args""" -def test_register_calls_model_data_source_not_supported(sagemaker_session): +@patch("sagemaker.get_model_package_args") +def test_register_passes_source_uri_to_model_package_args( + get_model_package_args, sagemaker_session +): + source_dir = "s3://blah/blah/blah" + source_uri = "dummy_source_uri" + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + t.register( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_RESPONSE_MIME_TYPES, + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, + marketplace_cert=True, + description=MODEL_DESCRIPTION, + model_package_name=MODEL_NAME, + validation_specification=VALIDATION_SPECIFICATION, + source_uri=source_uri, + ) + + # check that the kwarg source_uri was passed to the internal method 'get_model_package_args' + assert ( + "source_uri" in get_model_package_args.call_args_list[0][1] + ), "source_uri kwarg was not passed to get_model_package_args" + + # check that the kwarg source_uri is identical to the one passed into the method 'register' + assert ( + source_uri == get_model_package_args.call_args_list[0][1]["source_uri"] + ), """source_uri from model.register method is not identical to source_uri from + get_model_package_args""" + + +def test_register_with_model_data_source_not_supported_for_unversioned_model(sagemaker_session): source_dir = "s3://blah/blah/blah" t = Model( entry_point=ENTRY_POINT_INFERENCE, @@ -1137,7 +1176,7 @@ def test_register_calls_model_data_source_not_supported(sagemaker_session): with pytest.raises( ValueError, - match="SageMaker Model Package currently cannot be created with ModelDataSource.", + match="Un-versioned SageMaker Model Package currently cannot be created with ModelDataSource.", ): t.register( SUPPORTED_CONTENT_TYPES, @@ -1151,6 +1190,51 @@ def test_register_calls_model_data_source_not_supported(sagemaker_session): ) +@patch("sagemaker.get_model_package_args") +def test_register_with_model_data_source_supported_for_versioned_model( + get_model_package_args, sagemaker_session +): + source_dir = "s3://blah/blah/blah" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=model_data_source, + ) + + t.register( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_RESPONSE_MIME_TYPES, + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, + marketplace_cert=True, + description=MODEL_DESCRIPTION, + model_package_group_name="dummy_group", + validation_specification=VALIDATION_SPECIFICATION, + ) + + # check that the kwarg container_def_list was set for the internal method 'get_model_package_args' + assert ( + "container_def_list" in get_model_package_args.call_args_list[0][1] + ), "container_def_list kwarg was not set to get_model_package_args" + + # check that the kwarg container in container_def_list contains the model data source + assert ( + model_data_source + == get_model_package_args.call_args_list[0][1]["container_def_list"][0]["ModelDataSource"] + ), """model_data_source from model.register method is not identical to ModelDataSource from + get_model_package_args""" + + @patch("sagemaker.utils.repack_model") def test_model_local_download_dir(repack_model, sagemaker_session): diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index def7ddf5e3..9bfc830a75 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -223,22 +223,21 @@ def test_create_sagemaker_model_include_tags(sagemaker_session): ) -def test_model_package_model_data_source_not_supported(sagemaker_session): - with pytest.raises( - ValueError, match="Creating ModelPackage with ModelDataSource is currently not supported" - ): - ModelPackage( - role="role", - model_package_arn="my-model-package", - model_data={ - "S3DataSource": { - "S3Uri": "s3://bucket/model/prefix/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } - }, - sagemaker_session=sagemaker_session, - ) +def test_model_package_model_data_source_supported(sagemaker_session): + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + model_package = ModelPackage( + role="role", + model_package_arn="my-model-package", + model_data=model_data_source, + sagemaker_session=sagemaker_session, + ) + assert model_package.model_data == model_package.model_data @patch("sagemaker.utils.name_from_base") @@ -399,3 +398,47 @@ def test_add_inference_specification(sagemaker_session): } ], ) + + +def test_update_inference_specification(sagemaker_session): + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + + image_uris = ["image_uri"] + + containers = [{"Image": "image_uri"}] + + try: + model_package.update_inference_specification(image_uris=image_uris, containers=containers) + except ValueError as ve: + assert "Should have either containers or image_uris for inference." in str(ve) + + try: + model_package.update_inference_specification() + except ValueError as ve: + assert "Should have either containers or image_uris for inference." in str(ve) + + model_package.update_inference_specification(image_uris=image_uris) + + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, + InferenceSpecification={ + "Containers": [{"Image": "image_uri"}], + }, + ) + + +def test_update_source_uri(sagemaker_session): + source_uri = "dummy_source_uri" + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + model_package.update_source_uri(source_uri=source_uri) + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, SourceUri=source_uri + ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 358cabd0f8..ee11f5a1f3 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5119,6 +5119,233 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) +def test_create_model_package_from_containers_with_source_uri_and_inference_spec(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "dummy-source-uri" + + created_versioned_mp_arn = ( + "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + ) + sagemaker_session.sagemaker_client.create_model_package = Mock( + return_value={"ModelPackageArn": created_versioned_mp_arn} + ) + + sagemaker_session.create_model_package_from_containers( + model_package_group_name=model_package_group_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + expected_create_mp_args = { + "ModelPackageGroupName": model_package_group_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + "SupportedTransformInstanceTypes": transform_instances, + }, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "SkipModelValidation": skip_model_validation, + } + + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + expected_update_mp_args = { + "ModelPackageArn": created_versioned_mp_arn, + "SourceUri": source_uri, + } + sagemaker_session.sagemaker_client.update_model_package.assert_called_once_with( + **expected_update_mp_args + ) + + +def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp(sagemaker_session): + model_package_name = "sagemaker-model-package" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "dummy-source-uri" + + with pytest.raises( + ValueError, + match="Un-versioned SageMaker Model Package currently cannot be created with source_uri.", + ): + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + + +def test_create_model_package_from_containers_with_source_uri_for_versioned_mp(sagemaker_session): + model_package_name = "sagemaker-model-package" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + containers = [{"Image": "dummy-image", "ModelDataSource": model_data_source}] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + + with pytest.raises( + ValueError, + match="Un-versioned SageMaker Model Package currently cannot be created with ModelDataSource.", + ): + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + ) + + +def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "arn:aws:sagemaker:us-west-2:123456789123:model-package/existing-mp" + + created_versioned_mp_arn = ( + "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + ) + sagemaker_session.sagemaker_client.create_model_package = Mock( + return_value={"ModelPackageArn": created_versioned_mp_arn} + ) + + sagemaker_session.create_model_package_from_containers( + model_package_group_name=model_package_group_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + expected_create_mp_args = { + "ModelPackageGroupName": model_package_group_name, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "SkipModelValidation": skip_model_validation, + "SourceUri": source_uri, + } + + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + sagemaker_session.sagemaker_client.update_model_package.assert_not_called() + + +def test_create_model_package_from_algorithm_with_model_data_source(sagemaker_session): + model_package_name = "sagemaker-model-package" + description = "dummy description" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + algorithm_arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees" + sagemaker_session.create_model_package_from_algorithm( + algorithm_arn=algorithm_arn, + model_data=model_data_source, + name=model_package_name, + description=description, + ) + expected_create_mp_args = { + "ModelPackageName": model_package_name, + "ModelPackageDescription": description, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "AlgorithmName": algorithm_arn, + "ModelDataSource": model_data_source, + } + ] + }, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + + +def test_create_model_package_from_algorithm_with_model_data_url(sagemaker_session): + model_package_name = "sagemaker-model-package" + description = "dummy description" + model_data_url = "s3://bucket/key" + algorithm_arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees" + sagemaker_session.create_model_package_from_algorithm( + algorithm_arn=algorithm_arn, + model_data=model_data_url, + name=model_package_name, + description=description, + ) + expected_create_mp_args = { + "ModelPackageName": model_package_name, + "ModelPackageDescription": description, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "AlgorithmName": algorithm_arn, + "ModelDataUrl": model_data_url, + } + ] + }, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + + def test_create_model_package_from_containers_all_args(sagemaker_session): model_package_name = "sagemaker-model-package" containers = ["dummy-container"] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 852ef8b153..a83f1b995d 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -46,6 +46,7 @@ _is_bad_path, _is_bad_link, custom_extractall_tarfile, + can_model_package_source_uri_autopopulate, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1796,3 +1797,15 @@ def test_is_bad_link(link_name, base, expected): def test_custom_extractall_tarfile(mock_custom_tarfile, data_filter, expected_extract_path): tar = mock_custom_tarfile(data_filter) custom_extractall_tarfile(tar, "/extract/path") + + +def test_can_model_package_source_uri_autopopulate(): + test_data = [ + ("arn:aws:sagemaker:us-west-2:012345678912:model-package/dummy-mpg/1", True), + ("arn:aws:sagemaker:us-west-2:012345678912:model-package/dummy-mp", True), + ("arn:aws:sagemaker:us-west-2:012345678912:model/dummy-model", True), + ("https://path/to/model", False), + ("/home/path/to/model", False), + ] + for source_uri, expected in test_data: + assert can_model_package_source_uri_autopopulate(source_uri) == expected From 8e400e981cc9f577065257c62970b5753bb88b52 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 12 Mar 2024 16:08:33 -0400 Subject: [PATCH 06/20] feat: support JumpStart proprietary models (#4467) * feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> --- doc/doc_utils/jumpstart_doc_utils.py | 83 ++++- src/sagemaker/accept_types.py | 3 + src/sagemaker/base_predictor.py | 4 +- src/sagemaker/content_types.py | 3 + src/sagemaker/deserializers.py | 3 + src/sagemaker/instance_types.py | 3 + src/sagemaker/jumpstart/accessors.py | 29 +- .../jumpstart/artifacts/instance_types.py | 3 + src/sagemaker/jumpstart/artifacts/kwargs.py | 5 + .../jumpstart/artifacts/model_packages.py | 3 + src/sagemaker/jumpstart/artifacts/payloads.py | 3 + .../jumpstart/artifacts/predictors.py | 15 + .../jumpstart/artifacts/resource_names.py | 3 + .../artifacts/resource_requirements.py | 3 + src/sagemaker/jumpstart/cache.py | 249 ++++++++++---- src/sagemaker/jumpstart/constants.py | 17 +- src/sagemaker/jumpstart/enums.py | 14 + src/sagemaker/jumpstart/estimator.py | 13 +- src/sagemaker/jumpstart/exceptions.py | 70 +++- src/sagemaker/jumpstart/factory/estimator.py | 4 +- src/sagemaker/jumpstart/factory/model.py | 48 ++- src/sagemaker/jumpstart/filters.py | 24 ++ src/sagemaker/jumpstart/model.py | 74 ++++- src/sagemaker/jumpstart/notebook_utils.py | 56 +++- src/sagemaker/jumpstart/types.py | 53 ++- src/sagemaker/jumpstart/utils.py | 56 +++- src/sagemaker/model.py | 2 +- src/sagemaker/payloads.py | 7 + src/sagemaker/predictor.py | 3 + src/sagemaker/resource_requirements.py | 3 + src/sagemaker/serializers.py | 3 + .../jumpstart/model/test_jumpstart_model.py | 26 +- .../jumpstart/test_accept_types.py | 26 +- .../jumpstart/test_content_types.py | 25 +- .../jumpstart/test_deserializers.py | 25 +- .../jumpstart/test_default.py | 38 ++- .../hyperparameters/jumpstart/test_default.py | 10 +- .../jumpstart/test_validate.py | 33 +- .../image_uris/jumpstart/test_common.py | 11 +- .../jumpstart/test_instance_types.py | 14 +- tests/unit/sagemaker/jumpstart/constants.py | 81 +++++ .../jumpstart/estimator/test_estimator.py | 158 ++++----- .../estimator/test_sagemaker_config.py | 49 +-- .../sagemaker/jumpstart/model/test_model.py | 218 ++++++++----- .../jumpstart/model/test_sagemaker_config.py | 49 +-- .../sagemaker/jumpstart/test_accessors.py | 90 ++++- .../sagemaker/jumpstart/test_artifacts.py | 14 +- tests/unit/sagemaker/jumpstart/test_cache.py | 307 ++++++++++++++++-- .../sagemaker/jumpstart/test_exceptions.py | 34 ++ .../jumpstart/test_notebook_utils.py | 114 +++++-- .../sagemaker/jumpstart/test_predictor.py | 52 ++- tests/unit/sagemaker/jumpstart/test_utils.py | 94 ++++-- tests/unit/sagemaker/jumpstart/utils.py | 49 ++- .../jumpstart/test_default.py | 19 +- .../model_uris/jumpstart/test_common.py | 11 +- .../jumpstart/test_resource_requirements.py | 16 +- .../script_uris/jumpstart/test_common.py | 16 +- .../serializers/jumpstart/test_serializers.py | 15 +- 58 files changed, 1963 insertions(+), 490 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 348de7adeb..458da694d5 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -74,9 +74,12 @@ class Frameworks(str, Enum): JUMPSTART_REGION = "eu-west-2" SDK_MANIFEST_FILE = "models_manifest.json" +PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json" JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format( JUMPSTART_REGION, JUMPSTART_REGION ) +PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com" + TASK_MAP = { Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION, Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING, @@ -152,18 +155,26 @@ class Frameworks(str, Enum): } -def get_jumpstart_sdk_manifest(): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE) +def get_public_s3_json_object(url): with request.urlopen(url) as f: models_manifest = f.read().decode("utf-8") return json.loads(models_manifest) -def get_jumpstart_sdk_spec(key): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key) - with request.urlopen(url) as f: - model_spec = f.read().decode("utf-8") - return json.loads(model_spec) +def get_jumpstart_sdk_manifest(): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}") + + +def get_proprietary_sdk_manifest(): + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}") + + +def get_jumpstart_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}") + + +def get_proprietary_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{s3_key}") def get_model_task(id): @@ -196,6 +207,45 @@ def get_model_source(url): return "Source" +def create_proprietary_model_table(): + proprietary_content_intro = [] + proprietary_content_intro.append("\n") + proprietary_content_intro.append(".. list-table:: Available Proprietary Models\n") + proprietary_content_intro.append(" :widths: 50 20 20 20 20\n") + proprietary_content_intro.append(" :header-rows: 1\n") + proprietary_content_intro.append(" :class: datatable\n") + proprietary_content_intro.append("\n") + proprietary_content_intro.append(" * - Model ID\n") + proprietary_content_intro.append(" - Fine Tunable?\n") + proprietary_content_intro.append(" - Supported Version\n") + proprietary_content_intro.append(" - Min SDK Version\n") + proprietary_content_intro.append(" - Source\n") + + sdk_manifest = get_proprietary_sdk_manifest() + sdk_manifest_top_versions_for_models = {} + + for model in sdk_manifest: + if model["model_id"] not in sdk_manifest_top_versions_for_models: + sdk_manifest_top_versions_for_models[model["model_id"]] = model + else: + if str(sdk_manifest_top_versions_for_models[model["model_id"]]["version"]) < str( + model["version"] + ): + sdk_manifest_top_versions_for_models[model["model_id"]] = model + + proprietary_content_entries = [] + for model in sdk_manifest_top_versions_for_models.values(): + model_spec = get_proprietary_sdk_spec(model["spec_key"]) + proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training + proprietary_content_entries.append(" - {}\n".format(model["version"])) + proprietary_content_entries.append(" - {}\n".format(model["min_version"])) + proprietary_content_entries.append( + " - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url")) + ) + return proprietary_content_intro + proprietary_content_entries + ["\n"] + + def create_jumpstart_model_table(): sdk_manifest = get_jumpstart_sdk_manifest() sdk_manifest_top_versions_for_models = {} @@ -249,19 +299,19 @@ def create_jumpstart_model_table(): file_content_intro.append(" - Source\n") dynamic_table_files = [] - file_content_entries = [] + open_weight_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) model_task = get_model_task(model_spec["model_id"]) string_model_task = get_string_model_task(model_spec["model_id"]) model_source = get_model_source(model_spec["url"]) - file_content_entries.append(" * - {}\n".format(model_spec["model_id"])) - file_content_entries.append(" - {}\n".format(model_spec["training_supported"])) - file_content_entries.append(" - {}\n".format(model["version"])) - file_content_entries.append(" - {}\n".format(model["min_version"])) - file_content_entries.append(" - {}\n".format(model_task)) - file_content_entries.append( + open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"])) + open_weight_content_entries.append(" - {}\n".format(model["version"])) + open_weight_content_entries.append(" - {}\n".format(model["min_version"])) + open_weight_content_entries.append(" - {}\n".format(model_task)) + open_weight_content_entries.append( " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) ) @@ -299,7 +349,10 @@ def create_jumpstart_model_table(): f.writelines(file_content_single_entry) f.close() + proprietary_content_entries = create_proprietary_model_table() + f = open("doc_utils/pretrainedmodels.rst", "a") f.writelines(file_content_intro) - f.writelines(file_content_entries) + f.writelines(open_weight_content_entries) + f.writelines(proprietary_content_entries) f.close() diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index bf081365ab..78aa655e04 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -75,6 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -114,4 +116,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 882cfafc39..76b83c25cd 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -58,7 +58,9 @@ from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME from sagemaker.lineage.context import EndpointContext -from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.compute_resource_requirements.resource_requirements import ( + ResourceRequirements, +) LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index e43e96be17..46d0361f67 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -75,6 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -114,6 +116,7 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 706ae56bda..1a4be43897 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -35,6 +35,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -95,6 +96,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -135,4 +137,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 0471f374ae..48aaab0ac8 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -85,6 +87,7 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, + model_type=model_type, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index e03a13a7a3..35df030ddc 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -18,6 +18,7 @@ from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -197,7 +198,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: @staticmethod def _get_manifest( - region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. @@ -215,13 +218,19 @@ def _get_manifest( additional_kwargs.update({"s3_client": s3_client}) cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, + region, ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - return JumpStartModelsAccessor._cache.get_manifest() # type: ignore + return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore @staticmethod - def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: + def get_model_header( + region: str, + model_id: str, + version: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. Args: @@ -234,12 +243,18 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_header( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, + semantic_version_str=version, + model_type=model_type, ) @staticmethod def get_model_specs( - region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None + region: str, + model_id: str, + version: str, + s3_client: Optional[boto3.client] = None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. @@ -260,7 +275,7 @@ def get_model_specs( ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_specs( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version, model_type=model_type ) @staticmethod diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 38e02e3ebd..608303c5e6 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -38,6 +39,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model. @@ -84,6 +86,7 @@ def _retrieve_default_instance_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 7acad9b793..c15f686805 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -35,6 +36,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model`. @@ -71,6 +73,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -89,6 +92,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -128,6 +132,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index bd0ae365d9..5c8a2488c6 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.session import Session @@ -35,6 +36,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -74,6 +76,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3ea2c16f80..0424145119 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -20,6 +20,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( @@ -35,6 +36,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -72,6 +74,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 8d599c89cc..e9e0e8dfde 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -26,6 +26,7 @@ from sagemaker.jumpstart.enums import ( JumpStartScriptScope, MIMEType, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -76,6 +77,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -108,6 +110,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -120,6 +123,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -151,6 +155,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -163,6 +168,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -194,6 +200,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) seen_classes: Set[Type] = set() @@ -276,6 +283,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -312,6 +320,7 @@ def _retrieve_default_content_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -325,6 +334,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model. @@ -360,6 +370,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -374,6 +385,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -409,6 +421,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -423,6 +436,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported content types for the model. @@ -458,6 +472,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 6b05f07b15..60af520a6e 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -19,6 +19,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -32,6 +33,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. @@ -68,6 +70,7 @@ def _retrieve_resource_name_base( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 6ee4f31c56..9f01a7af77 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -21,6 +21,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -50,6 +51,7 @@ def _retrieve_default_resources( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: @@ -97,6 +99,7 @@ def _retrieve_default_resources( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e26d588167..7682ab3817 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -25,11 +25,17 @@ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, + MODEL_TYPE_TO_MANIFEST_MAP, + MODEL_TYPE_TO_SPECS_MAP, +) +from sagemaker.jumpstart.exceptions import ( + get_wildcard_model_version_msg, + get_wildcard_proprietary_model_version_msg, ) -from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, @@ -44,6 +50,7 @@ JumpStartS3FileType, JumpStartVersionedModelId, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache @@ -68,6 +75,7 @@ def __init__( JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, @@ -100,14 +108,26 @@ def __init__( expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, ) - self._model_id_semantic_version_manifest_key_cache = LRUCache[ + self._open_weight_model_id_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId ]( max_cache_items=max_semantic_version_cache_items, expiration_horizon=semantic_version_cache_expiration_horizon, - retrieval_function=self._get_manifest_key_from_model_id_semantic_version, + retrieval_function=self._get_open_weight_manifest_key_from_model_id, + ) + self._proprietary_model_id_manifest_key_cache = LRUCache[ + JumpStartVersionedModelId, JumpStartVersionedModelId + ]( + max_cache_items=max_semantic_version_cache_items, + expiration_horizon=semantic_version_cache_expiration_horizon, + retrieval_function=self._get_proprietary_manifest_key_from_model_id, ) self._manifest_file_s3_key = manifest_file_s3_key + self._proprietary_manifest_s3_key = proprietary_manifest_s3_key + self._manifest_file_s3_map = { + JumpStartModelType.OPEN_WEIGHTS: self._manifest_file_s3_key, + JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key, + } self.s3_bucket_name = ( utils.get_jumpstart_content_bucket(self._region) if s3_bucket_name is None @@ -129,15 +149,40 @@ def get_region(self) -> str: """Return region for cache.""" return self._region - def set_manifest_file_s3_key(self, key: str) -> None: - """Set manifest file s3 key. Clears cache after new key is set.""" - if key != self._manifest_file_s3_key: - self._manifest_file_s3_key = key + def set_manifest_file_s3_key( + self, + key: str, + file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + ) -> None: + """Set manifest file s3 key, clear cache after new key is set. + + Raises: + ValueError: if the file type is not recognized + """ + file_mapping = { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: self._manifest_file_s3_key, + JumpStartS3FileType.PROPRIETARY_MANIFEST: self._proprietary_manifest_s3_key, + } + property_name = file_mapping.get(file_type) + if not property_name: + raise ValueError( + self._file_type_error_msg(file_type, manifest_only=True) + ) + if key != property_name: + setattr(self, property_name, key) self.clear() - def get_manifest_file_s3_key(self) -> str: + def get_manifest_file_s3_key( + self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST + ) -> str: """Return manifest file s3 key for cache.""" - return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return self._proprietary_manifest_s3_key + raise ValueError( + self._file_type_error_msg(file_type, manifest_only=True) + ) def set_s3_bucket_name(self, s3_bucket_name: str) -> None: """Set s3 bucket used for cache.""" @@ -149,10 +194,24 @@ def get_bucket(self) -> str: """Return bucket used for cache.""" return self.s3_bucket_name - def _get_manifest_key_from_model_id_semantic_version( + def _file_type_error_msg(self, file_type: str, manifest_only: bool = False) -> str: + """Return error message for bad model type.""" + if manifest_only: + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST} " + f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}." + ) + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in '{' '.join([e.name for e in JumpStartS3FileType])}'." + ) + + def _model_id_retrieval_function( self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + model_type: JumpStartModelType, ) -> JumpStartVersionedModelId: """Return model ID and version in manifest that matches semantic version/id. @@ -164,6 +223,8 @@ def _get_manifest_key_from_model_id_semantic_version( key (JumpStartVersionedModelId): Key for which to fetch versioned model ID. value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached model ID/version. + model_type (JumpStartModelType): JumpStart model type to indicate whether it is + open weights model or proprietary (Marketplace) model. Raises: KeyError: If the semantic version is not found in the manifest, or is found but @@ -171,21 +232,20 @@ def _get_manifest_key_from_model_id_semantic_version( """ model_id, version = key.model_id, key.version - + sm_version = utils.get_sagemaker_version() manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content - sm_version = utils.get_sagemaker_version() - versions_compatible_with_sagemaker = [ - Version(header.version) + header.version for header in manifest.values() # type: ignore if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] sm_compatible_model_version = self._select_version( - version, versions_compatible_with_sagemaker + model_id, version, versions_compatible_with_sagemaker, model_type ) if sm_compatible_model_version is not None: @@ -196,7 +256,7 @@ def _get_manifest_key_from_model_id_semantic_version( if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( - version, versions_incompatible_with_sagemaker + model_id, version, versions_incompatible_with_sagemaker, model_type ) if sm_incompatible_model_version is not None: @@ -226,15 +286,27 @@ def _get_manifest_key_from_model_id_semantic_version( f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " ) - other_model_id_version = self._select_version( - "*", versions_incompatible_with_sagemaker - ) # all versions here are incompatible with sagemaker + other_model_id_version = None + if model_type == JumpStartModelType.OPEN_WEIGHTS: + other_model_id_version = self._select_version( + model_id, "*", versions_incompatible_with_sagemaker, model_type + ) # all versions here are incompatible with sagemaker + elif model_type == JumpStartModelType.PROPRIETARY: + all_possible_model_id_version = [ + header.version for header in manifest.values() # type: ignore + if header.model_id == model_id + ] + other_model_id_version = ( + None + if not all_possible_model_id_version + else all_possible_model_id_version[0] + ) + if other_model_id_version is not None: error_msg += ( f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'." ) - else: possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0] @@ -242,6 +314,32 @@ def _get_manifest_key_from_model_id_semantic_version( raise KeyError(error_msg) + def _get_open_weight_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """For open weights models, retrieve model manifest key for open weight model. + + Filters models list by supported versions. + """ + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.OPEN_WEIGHTS + ) + + def _get_proprietary_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """For proprietary models, retrieve model manifest key for proprietary model. + + Filters models list by supported versions. + """ + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.PROPRIETARY + ) + def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]: """Returns json file from s3, along with its etag.""" response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key) @@ -286,11 +384,11 @@ def _get_json_file_from_local_override( filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" - if filetype == JumpStartS3FileType.MANIFEST: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: metadata_local_root = ( os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] ) - elif filetype == JumpStartS3FileType.SPECS: + elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") @@ -318,8 +416,10 @@ def _retrieval_function( """ file_type, s3_key = key.file_type, key.s3_key - - if file_type == JumpStartS3FileType.MANIFEST: + if file_type in { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartS3FileType.PROPRIETARY_MANIFEST, + }: if value is not None and not self._is_local_metadata_mode(): etag = self._get_json_md5_hash(s3_key) if etag == value.md5_hash: @@ -329,27 +429,36 @@ def _retrieval_function( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type == JumpStartS3FileType.SPECS: + if file_type in { + JumpStartS3FileType.OPEN_WEIGHT_SPECS, + JumpStartS3FileType.PROPRIETARY_SPECS, + }: formatted_body, _ = self._get_json_file(s3_key, file_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue( - formatted_content=model_specs - ) + return JumpStartCachedS3ContentValue(formatted_content=model_specs) raise ValueError( - f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" + self._file_type_error_msg(file_type) ) - def get_manifest(self) -> List[JumpStartModelHeader]: + def get_manifest( + self, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest - def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: + def get_header( + self, + model_id: str, + semantic_version_str: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> JumpStartModelHeader: """Return header for a given JumpStart model ID and semantic version. Args: @@ -358,29 +467,43 @@ def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModel header. """ - return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) + return self._get_header_impl( + model_id, semantic_version_str=semantic_version_str, model_type=model_type + ) def _select_version( self, - semantic_version_str: str, - available_versions: List[Version], + model_id: str, + version_str: str, + available_versions: List[str], + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Perform semantic version search on available versions. Args: - semantic_version_str (str): the semantic version for which to filter + version_str (str): the semantic version for which to filter available versions. available_versions (List[Version]): list of available versions. """ - if semantic_version_str == "*": + + if version_str == "*": if len(available_versions) == 0: return None return str(max(available_versions)) + if model_type == JumpStartModelType.PROPRIETARY: + if "*" in version_str: + raise KeyError( + get_wildcard_proprietary_model_version_msg( + model_id, version_str, available_versions + ) + ) + return version_str if version_str in available_versions else None + try: - spec = SpecifierSet(f"=={semantic_version_str}") + spec = SpecifierSet(f"=={version_str}") except InvalidSpecifier: - raise KeyError(f"Bad semantic version: {semantic_version_str}") + raise KeyError(f"Bad semantic version: {version_str}") available_versions_filtered = list(spec.filter(available_versions)) return ( str(max(available_versions_filtered)) if available_versions_filtered != [] else None @@ -391,6 +514,7 @@ def _get_header_impl( model_id: str, semantic_version_str: str, attempt: int = 0, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -402,14 +526,20 @@ def _get_header_impl( header. attempt (int): attempt number at retrieving a header. """ - - versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( - JumpStartVersionedModelId(model_id, semantic_version_str) - )[0] + if model_type == JumpStartModelType.OPEN_WEIGHTS: + versioned_model_id = self._open_weight_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] + elif model_type == JumpStartModelType.PROPRIETARY: + versioned_model_id = self._proprietary_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content + try: header = manifest[versioned_model_id] # type: ignore return header @@ -417,28 +547,34 @@ def _get_header_impl( if attempt > 0: raise self.clear() - return self._get_header_impl(model_id, semantic_version_str, attempt + 1) + return self._get_header_impl(model_id, semantic_version_str, attempt + 1, model_type) - def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs: + def get_specs( + self, + model_id: str, + version_str: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS + ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. Args: model_id (str): model ID for which to get specs. semantic_version_str (str): The semantic version for which to get specs. + model_type (JumpStartModelType): The type of the model of interest. """ - - header = self.get_header(model_id, semantic_version_str) + header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key + ) ) - if not cache_hit and "*" in semantic_version_str: + + if not cache_hit and "*" in version_str: JUMPSTART_LOGGER.warning( get_wildcard_model_version_msg( - header.model_id, - semantic_version_str, - header.version + header.model_id, version_str, header.version ) ) return specs.formatted_content @@ -446,4 +582,5 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS def clear(self) -> None: """Clears the model ID/version and s3 cache.""" self._s3_cache.clear() - self._model_id_semantic_version_manifest_key_cache.clear() + self._open_weight_model_id_manifest_key_cache.clear() + self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2e655ac285..be66e8968e 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -22,8 +22,9 @@ SerializerType, DeserializerType, MIMEType, + JumpStartModelType, ) -from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo +from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo, JumpStartS3FileType from sagemaker.base_serializers import ( BaseSerializer, CSVSerializer, @@ -169,6 +170,7 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" +JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" @@ -188,6 +190,9 @@ SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY = "SageMakerGatedModelS3Uri" +PROPRIETARY_MODEL_SPEC_PREFIX = "proprietary-models" +PROPRIETARY_MODEL_FILTER_NAME = "marketplace" + CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP: Dict[MIMEType, SerializerType] = { MIMEType.X_IMAGE: SerializerType.RAW_BYTES, MIMEType.LIST_TEXT: SerializerType.JSON, @@ -213,6 +218,16 @@ DeserializerType.JSON: JSONDeserializer, } +MODEL_TYPE_TO_MANIFEST_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_MANIFEST, +} + +MODEL_TYPE_TO_SPECS_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_SPECS, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_SPECS, +} + MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index e33daca046..afe0df0dfe 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -34,6 +34,19 @@ class ModelFramework(str, Enum): SKLEARN = "sklearn" +class JumpStartModelType(str, Enum): + """Enum class for JumpStart model type. + + OPEN_WEIGHTS: Publicly available models have open weights + and are onboarded and maintained by JumpStart. + PROPRIETARY: Proprietary models from third-party providers do not have open weights. + You must subscribe to proprietary models in AWS Marketplace before use. + """ + + OPEN_WEIGHTS = "open_weights" + PROPRIETARY = "proprietary" + + class VariableScope(str, Enum): """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. @@ -78,6 +91,7 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" + MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" class SerializerType(str, Enum): diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 24105c4369..4dada409f5 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -35,7 +35,7 @@ from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job from sagemaker.jumpstart.utils import ( - is_valid_model_id, + validate_model_id_and_get_type, resolve_model_sagemaker_config_field, ) from sagemaker.utils import stringify_object, format_tags, Tags @@ -504,8 +504,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_get_type_hook(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -513,14 +513,17 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index c55c9081cb..742a6b8d3f 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -14,7 +14,12 @@ from __future__ import absolute_import from typing import List, Optional -from sagemaker.jumpstart.constants import MODEL_ID_LIST_WEB_URL, JumpStartScriptScope +from botocore.exceptions import ClientError + +from sagemaker.jumpstart.constants import ( + MODEL_ID_LIST_WEB_URL, + JumpStartScriptScope, +) NO_AVAILABLE_INSTANCES_ERROR_MSG = ( "No instances available in {region} that can support model ID '{model_id}'. " @@ -28,7 +33,7 @@ INVALID_MODEL_ID_ERROR_MSG = ( "Invalid model ID: '{model_id}'. Please visit " - f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " + f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " "The module `sagemaker.jumpstart.notebook_utils` contains utilities for " "fetching model IDs. We recommend upgrading to the latest version of sagemaker " "to get access to the most models." @@ -58,6 +63,32 @@ def get_wildcard_model_version_msg( ) +def get_proprietary_model_subscription_msg( + model_id: str, + subscription_link: str, +) -> str: + """Returns customer-facing message for using a proprietary model.""" + + return ( + f"INFO: Using proprietary model '{model_id}'. " + f"To subscribe to this model in AWS Marketplace, see {subscription_link}" + ) + + +def get_wildcard_proprietary_model_version_msg( + model_id: str, wildcard_model_version: str, available_versions: List[str] +) -> str: + """Returns customer-facing message for passing wildcard version to proprietary models.""" + msg = ( + f"Proprietary model '{model_id}' does not support " + f"wildcard version identifier '{wildcard_model_version}'. " + ) + if len(available_versions) > 0: + msg += f"You can pin to version '{available_versions[0]}'. " + msg += f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " + return msg + + def get_old_model_version_msg( model_id: str, current_model_version: str, latest_model_version: str ) -> str: @@ -70,6 +101,15 @@ def get_old_model_version_msg( ) +def get_proprietary_model_subscription_error(error: ClientError, subscription_link: str) -> None: + """Returns customer-facing message associated with a Marketplace subscription error.""" + + error_code = error.response["Error"]["Code"] + error_message = error.response["Error"]["Message"] + if error_code == "ValidationException" and "not subscribed" in error_message: + raise MarketplaceModelSubscriptionError(subscription_link) + + class JumpStartHyperparametersError(ValueError): """Exception raised for bad hyperparameters of a JumpStart model.""" @@ -169,3 +209,29 @@ def __init__( ) super().__init__(self.message) + + +class MarketplaceModelSubscriptionError(ValueError): + """Exception raised when trying to deploy a JumpStart Marketplace model. + + A caller is required to subscribe to the Marketplace product in order to deploy. + This exception is raised when a caller tries to deploy a JumpStart Marketplace model + but the caller is not subscribed to the model. + """ + + def __init__( + self, + model_subscription_link: Optional[str] = None, + message: Optional[str] = None, + ): + if message: + self.message = message + else: + self.message = ( + "To use a proprietary JumpStart model, " + "you must first subscribe to the model in AWS Marketplace. " + ) + if model_subscription_link: + self.message += f"To subscribe to this model, see {model_subscription_link}" + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7ccf57983b..7c20c281f5 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -50,7 +50,7 @@ TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model from sagemaker.jumpstart.types import ( JumpStartEstimatorDeployKwargs, @@ -77,6 +77,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -134,6 +135,7 @@ def get_init_kwargs( estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 40759d0f0b..63b4898877 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -37,7 +37,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( JumpStartModelDeployKwargs, JumpStartModelInitKwargs, @@ -71,6 +71,7 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -92,6 +93,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -100,6 +102,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -108,6 +111,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -116,6 +120,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return predictor @@ -187,6 +192,7 @@ def _add_instance_type_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, + model_type=kwargs.model_type, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -199,7 +205,14 @@ def _add_instance_type_to_kwargs( def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: - """Sets image uri based on default or override, returns full kwargs.""" + """Sets image uri based on default or override, returns full kwargs. + + Uses placeholder image uri for JumpStart proprietary models that uses ModelPackages + """ + + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.image_uri = None + return kwargs kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( region=kwargs.region, @@ -219,6 +232,10 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.model_data = None + return kwargs + model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve( model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, @@ -255,6 +272,10 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets source dir based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.source_dir = None + return kwargs + source_dir = kwargs.source_dir if _model_supports_inference_script_uri( @@ -283,6 +304,10 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets entry point based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.entry_point = None + return kwargs + entry_point = kwargs.entry_point if _model_supports_inference_script_uri( @@ -304,6 +329,10 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets env based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.env = None + return kwargs + env = kwargs.env if env is None: @@ -348,6 +377,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.model_package_arn = model_package_arn @@ -364,6 +394,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in model_kwargs_to_add.items(): @@ -399,6 +430,7 @@ def _add_endpoint_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -420,6 +452,7 @@ def _add_model_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.name = kwargs.name or ( @@ -440,11 +473,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type ) return kwargs @@ -461,6 +495,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in deploy_kwargs_to_add.items(): @@ -481,6 +516,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, instance_type=kwargs.instance_type, ) @@ -490,6 +526,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -522,6 +559,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -657,6 +695,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -688,6 +727,7 @@ def get_init_kwargs( model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, instance_type=instance_type, region=region, image_uri=image_uri, @@ -732,14 +772,12 @@ def get_init_kwargs( # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_source_dir_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_entry_point_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_env_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_predictor_cls_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index b045435ed0..fc5113315d 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -49,6 +49,14 @@ class SpecialSupportedFilterKeys(str, Enum): TASK = "task" FRAMEWORK = "framework" + MODEL_TYPE = "model_type" + + +class ProprietaryModelFilterIdentifiers(str, Enum): + """Enum class for proprietary model filter keys.""" + + PROPRIETARY = "proprietary" + MARKETPLACE = "marketplace" FILTER_OPERATOR_STRING_MAPPINGS = { @@ -429,6 +437,22 @@ def __init__(self, key: str, value: str, operator: str): self.value = value self.operator = operator + def set_key(self, key: str) -> None: + """Sets the key for the model filter. + + Args: + key (str): The key to be set. + """ + self.key = key + + def set_value(self, value: str) -> None: + """Sets the value for the model filter. + + Args: + value (str): The value to be set. + """ + self.value = value + def parse_filter_string(filter_string: str) -> ModelFilter: """Parse filter string and return a serialized ``ModelFilter`` object. diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 8d007aed24..a8f4c0e9fc 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -15,14 +15,21 @@ from __future__ import absolute_import from typing import Dict, List, Optional, Union +from botocore.exceptions import ClientError + from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer from sagemaker.base_serializers import BaseSerializer +from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.enums import JumpStartScriptScope -from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG +from sagemaker.jumpstart.exceptions import ( + INVALID_MODEL_ID_ERROR_MSG, + get_proprietary_model_subscription_error, + get_proprietary_model_subscription_msg, +) from sagemaker.jumpstart.factory.model import ( get_default_predictor, get_deploy_kwargs, @@ -30,7 +37,12 @@ get_register_kwargs, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload -from sagemaker.jumpstart.utils import is_valid_model_id +from sagemaker.jumpstart.utils import ( + validate_model_id_and_get_type, + verify_model_region_and_return_specs, +) +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -45,7 +57,6 @@ from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements -from sagemaker.enums import EndpointType class JumpStartModel(Model): @@ -270,8 +281,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_type(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -279,16 +290,18 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_type() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_type() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) self._model_data_is_set = model_data is not None - model_init_kwargs = get_init_kwargs( model_id=model_id, model_from_estimator=False, + model_type=self.model_type, model_version=model_version, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -326,10 +339,27 @@ def _is_valid_model_id_hook(): self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + if self.model_type == JumpStartModelType.PROPRIETARY: + self.log_subscription_warning() + super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) self.model_package_arn = model_init_kwargs.model_package_arn + def log_subscription_warning(self) -> None: + """Log message prompting the customer to subscribe to the proprietary model.""" + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + ).model_subscription_link + JUMPSTART_LOGGER.warning( + get_proprietary_model_subscription_msg(self.model_id, subscription_link) + ) + def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: """Returns all example payloads associated with the model. @@ -347,6 +377,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self.model_type, ) def retrieve_example_payload(self) -> JumpStartSerializablePayload: @@ -364,6 +395,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: return payloads.retrieve_example( model_id=self.model_id, model_version=self.model_version, + model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -558,6 +590,9 @@ def deploy( endpoint. endpoint_type (EndpointType): The type of endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + + Raises: + MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. """ deploy_kwargs = get_deploy_kwargs( @@ -589,9 +624,29 @@ def deploy( resources=resources, managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, + model_type=self.model_type, ) + if ( + self.model_type == JumpStartModelType.PROPRIETARY + and endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED + ): + raise ValueError( + f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." + ) - predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) + try: + predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) + except ClientError as e: + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + ).model_subscription_link + get_proprietary_model_subscription_error(e, subscription_link) + raise # If no predictor class was passed, add defaults to predictor if self.orig_predictor_cls is None and async_inference_config is None: @@ -603,6 +658,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self.model_type, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 1554025995..485354e802 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -24,10 +24,12 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, + PROPRIETARY_MODEL_SPEC_PREFIX, ) -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.filters import ( SPECIAL_SUPPORTED_FILTER_KEYS, + ProprietaryModelFilterIdentifiers, BooleanValues, Identity, SpecialSupportedFilterKeys, @@ -38,6 +40,7 @@ get_jumpstart_content_bucket, get_sagemaker_version, verify_model_region_and_return_specs, + validate_model_id_and_get_type, ) from sagemaker.session import Session @@ -124,15 +127,11 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: Args: model_id (str): The model ID for which to extract the framework/task/model. - - Raises: - ValueError: If the model ID cannot be parsed into at least 3 components seperated by - "-" character. """ _id_parts = model_id.split("-") if len(_id_parts) < 3: - raise ValueError(f"incorrect model ID: {model_id}.") + return "", "", "" framework = _id_parts[0] task = _id_parts[1] @@ -141,6 +140,20 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: return framework, task, name +def extract_model_type_filter_representation(spec_key: str) -> str: + """Parses model spec key, determine if the model is proprietary or open weight. + + Args: + spek_key (str): The model spec key for which to extract the model type. + """ + model_spec_prefix = spec_key.split("/")[0] + + if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX: + return JumpStartModelType.PROPRIETARY.value + + return JumpStartModelType.OPEN_WEIGHTS.value + + def list_jumpstart_tasks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, @@ -321,14 +334,22 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=sagemaker_session.s3_client + prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.PROPRIETARY, ) + open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + ) + models_manifest_list = open_weight_manifest_list + prop_models_manifest_list if isinstance(filter, str): filter = Identity(filter) - manifest_keys = set(models_manifest_list[0].__slots__) + manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__) all_keys: Set[str] = set() @@ -338,6 +359,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin model_filter = operator.unresolved_value key = model_filter.key all_keys.add(key) + if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in { + identifier.value for identifier in ProprietaryModelFilterIdentifiers + }: + model_filter.set_value(JumpStartModelType.PROPRIETARY.value) model_filters.add(model_filter) for key in all_keys: @@ -351,6 +376,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys + is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]: @@ -373,6 +399,11 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, SpecialSupportedFilterKeys.FRAMEWORK ] = extract_framework_task_model(model_manifest.model_id)[0] + if is_model_type_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.MODEL_TYPE + ] = extract_model_type_filter_representation(model_manifest.spec_key) + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): return None @@ -466,6 +497,12 @@ def get_model_url( sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to retrieve the model url. """ + model_type = validate_model_id_and_get_type( + model_id=model_id, + model_version=model_version, + region=region, + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( region=region, @@ -473,5 +510,6 @@ def get_model_url( version=model_version, sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, + model_type=model_type, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 43a25f3c12..3142596eba 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -19,6 +19,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable @@ -102,8 +103,10 @@ def __repr__(self) -> str: class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" - MANIFEST = "manifest" - SPECS = "specs" + OPEN_WEIGHT_MANIFEST = "manifest" + OPEN_WEIGHT_SPECS = "specs" + PROPRIETARY_MANIFEST = "proptietary_manifest" + PROPRIETARY_SPECS = "proprietary_specs" class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -788,6 +791,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_instance_type_variants", "default_payloads", "gated_bucket", + "model_subscription_link", ] def __init__(self, spec: Dict[str, Any]): @@ -805,29 +809,31 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ self.model_id: str = json_obj["model_id"] - self.url: str = json_obj["url"] + self.url: str = json_obj.get("url", "") self.version: str = json_obj["version"] self.min_sdk_version: str = json_obj["min_sdk_version"] - self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) + self.incremental_training_supported: bool = bool( + json_obj.get("incremental_training_supported", False) + ) self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) if "hosting_ecr_specs" in json_obj else None ) - self.hosting_artifact_key: str = json_obj["hosting_artifact_key"] - self.hosting_script_key: str = json_obj["hosting_script_key"] - self.training_supported: bool = bool(json_obj["training_supported"]) + self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") + self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ JumpStartEnvironmentVariable(env_variable) - for env_variable in json_obj["inference_environment_variables"] + for env_variable in json_obj.get("inference_environment_variables", []) ] - self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"]) - self.inference_dependencies: List[str] = json_obj["inference_dependencies"] - self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"] - self.training_vulnerable: bool = bool(json_obj["training_vulnerable"]) - self.training_dependencies: List[str] = json_obj["training_dependencies"] - self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] - self.deprecated: bool = bool(json_obj["deprecated"]) + self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) + self.inference_dependencies: List[str] = json_obj.get("inference_dependencies", []) + self.inference_vulnerabilities: List[str] = json_obj.get("inference_vulnerabilities", []) + self.training_vulnerable: bool = bool(json_obj.get("training_vulnerable", False)) + self.training_dependencies: List[str] = json_obj.get("training_dependencies", []) + self.training_vulnerabilities: List[str] = json_obj.get("training_vulnerabilities", []) + self.deprecated: bool = bool(json_obj.get("deprecated", False)) self.deprecated_message: Optional[str] = json_obj.get("deprecated_message") self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message") self.usage_info_message: Optional[str] = json_obj.get("usage_info_message") @@ -917,6 +923,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("training_instance_type_variants") else None ) + self.model_subscription_link = json_obj.get("model_subscription_link") def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" @@ -1049,6 +1056,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1079,6 +1087,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "instance_type", "model_id", "model_version", + "model_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1090,6 +1099,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1119,6 +1129,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.instance_type = instance_type self.region = region self.image_uri = image_uri @@ -1151,6 +1162,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "initial_instance_count", "instance_type", "region", @@ -1182,6 +1194,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1193,6 +1206,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1224,6 +1238,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1258,6 +1273,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "instance_type", "instance_count", "region", @@ -1317,12 +1333,14 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "model_type", } def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1379,6 +1397,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = (model_type,) self.instance_type = instance_type self.instance_count = instance_count self.region = region @@ -1440,6 +1459,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "region", "inputs", "wait", @@ -1454,6 +1474,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1464,6 +1485,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, @@ -1478,6 +1500,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.region = region self.inputs = inputs self.wait = wait diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 2621422811..71a8067a6f 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version @@ -26,6 +26,7 @@ TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, TRAINING_JOB_ROLE_ARN_PATH, ) + from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors from sagemaker.s3 import parse_s3_url @@ -314,6 +315,7 @@ def add_single_jumpstart_tag( ( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) + or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) ) if is_uri else False @@ -348,6 +350,7 @@ def add_jumpstart_model_id_version_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, + model_type: Optional[enums.JumpStartModelType] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -364,6 +367,13 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if model_type == enums.JumpStartModelType.PROPRIETARY: + tags = add_single_jumpstart_tag( + enums.JumpStartModelType.PROPRIETARY.value, + enums.JumpStartTag.MODEL_TYPE, + tags, + is_uri=False, + ) return tags @@ -530,6 +540,7 @@ def verify_model_region_and_return_specs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -578,6 +589,7 @@ def verify_model_region_and_return_specs( model_id=model_id, version=version, s3_client=sagemaker_session.s3_client, + model_type=model_type, ) if ( @@ -732,36 +744,52 @@ def resolve_estimator_sagemaker_config_field( return field_val -def is_valid_model_id( +def validate_model_id_and_get_type( model_id: Optional[str], region: Optional[str] = None, model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> bool: - """Returns True if the model ID is supported for the given script. +) -> Optional[enums.JumpStartModelType]: + """Returns model type if the model ID is supported for the given script. Raises: ValueError: If the script is not supported by JumpStart. """ + + def _get_model_type( + model_id: str, + open_weights_model_ids: Set[str], + proprietary_model_ids: Set[str], + script: enums.JumpStartScriptScope, + ) -> Optional[enums.JumpStartModelType]: + if model_id in open_weights_model_ids: + return enums.JumpStartModelType.OPEN_WEIGHTS + if model_id in proprietary_model_ids: + if script == enums.JumpStartScriptScope.INFERENCE: + return enums.JumpStartModelType.PROPRIETARY + raise ValueError(f"Unsupported script for Marketplace models: {script}") + return None + if model_id in {None, ""}: - return False + return None if not isinstance(model_id, str): - return False + return None s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=s3_client + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHTS + ) + open_weight_model_id_set = {model.model_id for model in models_manifest_list} + + proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY ) - model_id_set = {model.model_id for model in models_manifest_list} - if script == enums.JumpStartScriptScope.INFERENCE: - return model_id in model_id_set - if script == enums.JumpStartScriptScope.TRAINING: - return model_id in model_id_set - raise ValueError(f"Unsupported script: {script}") + + proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list} + return _get_model_type(model_id, open_weight_model_id_set, proprietary_model_id_set, script) def get_jumpstart_model_id_version_from_resource_arn( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index af08d1203f..c5f8b86f60 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -141,7 +141,7 @@ class Model(ModelBase, InferenceRecommenderMixin): def __init__( self, - image_uri: Union[str, PipelineVariable], + image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 52d633ed4e..de33f61b82 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.payload_utils import PayloadSerializer from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -31,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -83,6 +85,7 @@ def retrieve_all_examples( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if unserialized_payload_dict is None: @@ -120,6 +123,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -133,6 +137,8 @@ def retrieve_example( the model payload. model_version (str): The version of the JumpStart model for which to retrieve the model payload. + model_type (str): The model type of the JumpStart model, either is open weight + or proprietary. serialize (bool): Whether to serialize byte-stream valued payloads by downloading binary files from s3 and applying encoding, or to keep payload in pre-serialized state. Set this option to False if you want to avoid s3 downloads or if you @@ -162,6 +168,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 42c2af0917..6f846bba65 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -15,6 +15,7 @@ from typing import Optional from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint @@ -41,6 +42,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. @@ -107,4 +109,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 93b2833a35..df14ac558f 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session LOGGER = logging.getLogger("sagemaker") @@ -33,6 +34,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: @@ -82,6 +84,7 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index fc76c0fa76..aefb52bd97 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -34,6 +34,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -93,6 +94,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -134,4 +136,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 24050807cc..5205765e2f 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -256,9 +256,33 @@ def test_jumpstart_model_register(setup): predictor = model_package.deploy( instance_type="ml.p3.2xlarge", initial_instance_count=1, - tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], ) response = predictor.predict("hello world!") assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_proprietary_jumpstart_model(setup): + + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + model_version="2.0.004", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}] + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + assert response is not None diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 28211d06f1..49c18beec2 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -17,21 +17,27 @@ from sagemaker import accept_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +51,26 @@ def test_jumpstart_default_accept_types( assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -73,5 +87,9 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 4b2db7d7f4..7765d6eaad 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -17,6 +17,7 @@ from sagemaker import content_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -24,14 +25,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +50,26 @@ def test_jumpstart_default_content_types( assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -72,5 +85,9 @@ def test_jumpstart_supported_content_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 9d6e2f21de..5328533da5 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -18,6 +18,7 @@ from sagemaker import base_deserializers, deserializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -26,14 +27,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_deserializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -47,18 +52,26 @@ def test_jumpstart_default_deserializers( assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_deserializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -79,5 +92,9 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index acd8d19923..38cc5ebbf3 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -18,17 +18,23 @@ import pytest from sagemaker import environment_variables +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_environment_variables(patched_get_model_specs): +def test_jumpstart_default_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -48,7 +54,11 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -68,7 +78,11 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -98,10 +112,14 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_sdk_environment_variables(patched_get_model_specs): +def test_jumpstart_sdk_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -122,7 +140,11 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -143,7 +165,11 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index eebc079164..a13fba87ae 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import hyperparameters +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -26,10 +27,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_hyperparameters(patched_get_model_specs): +def test_jumpstart_default_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -47,6 +52,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -64,6 +70,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -89,6 +96,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0054ed9dbd..7a5df4ac93 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -17,7 +17,7 @@ import pytest import boto3 from sagemaker import hyperparameters -from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter @@ -27,8 +27,11 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_provided_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.extend( @@ -109,6 +112,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -140,6 +144,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -398,8 +403,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_algorithm_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_algorithm_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.append( @@ -416,6 +424,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -437,7 +446,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -464,10 +477,14 @@ def add_options_to_hyperparameter(*largs, **kwargs): assert str(e.value) == "Cannot find algorithm hyperparameter for 'adam-learning-rate'." +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_all_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -491,7 +508,11 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 8a41891280..6c80c97f33 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -18,19 +18,24 @@ from sagemaker import image_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_image_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -49,6 +54,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -69,6 +75,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,6 +96,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,6 +117,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index bed2e50674..982c7f1702 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -18,14 +18,17 @@ import pytest from sagemaker import instance_types +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_instance_types(patched_get_model_specs): +def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_model_id_and_get_type): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" @@ -47,6 +50,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -65,6 +69,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -89,6 +94,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -111,7 +117,11 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 605253466a..ce8cc4ddfa 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6113,6 +6113,7 @@ "deprecated_message": None, "hosting_model_package_arns": None, "hosting_eula_key": None, + "model_subscription_link": None, "hyperparameters": [ { "name": "epochs", @@ -6309,3 +6310,83 @@ "imagenet-inception-v3-classification-4/specs_v3.0.0.json", }, ] + +BASE_PROPRIETARY_HEADER = { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], +} + +BASE_PROPRIETARY_MANIFEST = [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "lighton-mini-instruct40b", + "version": "v1.0", + "min_version": "2.0.0", + "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "1.0.005", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, +] + +BASE_PROPRIETARY_SPEC = { + "model_id": "ai21-jurassic-2-light", + "version": "2.0.004", + "min_sdk_version": "2.999.0", + "listing_id": "prodview-roz6zicyvi666", + "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", + "model_subscription_link": "https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", + "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 600, + }, + "default_payloads": { + "Shakespeare": { + "content_type": "application/json", + "prompt_key": "prompt", + "output_keys": {"generated_text": "[0].completions[0].data.text"}, + "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, + } + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "hosting_model_package_arns": { + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", + "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", + "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", + "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", + "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", + "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", + "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", + "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", + "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", + "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", + "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", + }, +} diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4dc35b65ca..4fa18f31aa 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -31,7 +31,7 @@ _retrieve_default_training_metric_definitions, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.estimator import JumpStartEstimator @@ -60,9 +60,10 @@ class EstimatorTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -75,14 +76,15 @@ def test_non_prepacked( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, + mock_get_model_type: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, mock_jumpstart_estimator_factory_logger: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "9876" @@ -92,6 +94,8 @@ def test_non_prepacked( mock_get_model_specs.side_effect = get_special_model_spec + mock_get_model_type.return_value = JumpStartModelType.OPEN_WEIGHTS + mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session @@ -182,7 +186,7 @@ def test_non_prepacked( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -199,11 +203,11 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -282,7 +286,7 @@ def test_prepacked( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -299,14 +303,14 @@ def test_gated_model_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-trainable-model", "*" @@ -418,7 +422,7 @@ def test_gated_model_s3_uri( "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -435,7 +439,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_get_jumpstart_gated_content_bucket: mock.Mock, ): @@ -444,7 +448,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" @@ -566,7 +570,7 @@ def test_gated_model_non_model_package_s3_uri( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -583,15 +587,13 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True - model_id, _ = "js-gated-artifact-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -599,6 +601,8 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + with pytest.raises(ValueError) as e: JumpStartEstimator(model_id=model_id, region="eu-north-1") @@ -608,7 +612,7 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( "us-west-2, us-east-1, eu-west-1, ap-southeast-1." ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -621,10 +625,10 @@ def test_deprecated( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -642,7 +646,7 @@ def test_deprecated( JumpStartEstimator(model_id=model_id, tolerate_deprecated_model=True).fit(channels).deploy() - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -655,9 +659,9 @@ def test_vulnerable( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "vulnerable_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -758,7 +762,7 @@ def test_estimator_use_kwargs(self): @mock.patch("sagemaker.jumpstart.factory.estimator.metric_definitions_utils.retrieve_default") @mock.patch("sagemaker.jumpstart.factory.estimator.environment_variables.retrieve_default") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -775,7 +779,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_retrieve_default_environment_variables: mock.Mock, mock_retrieve_metric_definitions: mock.Mock, @@ -806,7 +810,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "js-trainable-model", "*" @@ -908,16 +912,16 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.assert_called_once_with(**expected_deploy_kwargs) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags_disabled( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -947,16 +951,16 @@ def test_jumpstart_estimator_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -989,18 +993,18 @@ def test_jumpstart_estimator_tags( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", @@ -1032,18 +1036,18 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.side_effect = ValueError() @@ -1115,22 +1119,22 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartEstimator(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartEstimator(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1147,14 +1151,14 @@ def test_no_predictor_returns_default_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1189,7 +1193,7 @@ def test_no_predictor_returns_default_predictor( self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1206,14 +1210,14 @@ def test_no_predictor_yes_async_inference_config( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1239,7 +1243,7 @@ def test_no_predictor_yes_async_inference_config( self.assertEqual(type(predictor), Predictor) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1256,14 +1260,14 @@ def test_yes_predictor_returns_unmodified_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1289,7 +1293,7 @@ def test_yes_predictor_returns_unmodified_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1310,9 +1314,9 @@ def test_incremental_training_with_unsupported_model_logs_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1343,7 +1347,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( sagemaker_session=sagemaker_session, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1364,9 +1368,9 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1395,7 +1399,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1412,10 +1416,10 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1456,7 +1460,7 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1473,10 +1477,10 @@ def test_training_passes_role_to_deploy( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1533,7 +1537,7 @@ def test_training_passes_role_to_deploy( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session ) @@ -1553,10 +1557,10 @@ def test_training_passes_session_to_deploy( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1611,7 +1615,7 @@ def test_training_passes_session_to_deploy( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -1627,11 +1631,11 @@ def test_model_id_not_found_refeshes_cache_training( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -1647,7 +1651,7 @@ def test_model_id_not_found_refeshes_cache_training( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1666,16 +1670,16 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [False, True] JumpStartEstimator( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1694,7 +1698,7 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -1704,10 +1708,10 @@ def test_model_artifact_variant_estimator( mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index d22e910a00..073921d5ba 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -78,7 +79,7 @@ def config_value_impl(sagemaker_session: Session, config_path: str, sagemaker_co class IntelligentDefaultsEstimatorTest(unittest.TestCase): - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -98,12 +99,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_model_init_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -135,7 +136,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -155,12 +156,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -208,7 +209,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( config_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -228,12 +229,12 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -290,7 +291,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -310,12 +311,12 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -365,7 +366,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -387,12 +388,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -426,7 +427,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -448,12 +449,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_get_caller_identity_arn.return_value = execution_role model_id, _ = "js-trainable-model", "*" @@ -500,7 +501,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( metadata_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -520,11 +521,11 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -576,7 +577,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -594,11 +595,11 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index f45283935b..ba4ba0bb13 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -22,7 +22,7 @@ _retrieve_default_environment_variables, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model @@ -32,9 +32,11 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, + get_prototype_manifest, ) execution_role = "fake role! do not use!" @@ -53,7 +55,7 @@ class ModelTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -65,7 +67,7 @@ def test_non_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, ): @@ -73,7 +75,7 @@ def test_non_prepacked( mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -128,7 +130,7 @@ def test_non_prepacked( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -140,14 +142,14 @@ def test_non_prepacked_inference_component_based_endpoint( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = ( @@ -208,7 +210,7 @@ def test_non_prepacked_inference_component_based_endpoint( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -220,14 +222,14 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -282,7 +284,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -294,11 +296,11 @@ def test_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -345,7 +347,7 @@ def test_prepacked( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -353,7 +355,7 @@ def test_no_compiled_model_warning_log_js_models( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, @@ -362,7 +364,7 @@ def test_no_compiled_model_warning_log_js_models( mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_llama_neuron_model", "*" @@ -381,7 +383,7 @@ def test_no_compiled_model_warning_log_js_models( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -389,15 +391,14 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, ): - mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_variant-model", "*" @@ -441,7 +442,64 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_proprietary_model_endpoint( + self, + mock_model_deploy: mock.Mock, + mock_model_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_sagemaker_timestamp.return_value = "7777" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.PROPRIETARY + model_id, _ = "ai21-summarization", "*" + + mock_get_model_specs.side_effect = get_spec_from_base_spec + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, model_version="2.0.004") + + mock_model_init.assert_called_once_with( + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + ) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p4de.24xlarge", + wait=True, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.004"}, + {"Key": JumpStartTag.MODEL_TYPE, "Value": "proprietary"}, + ], + endpoint_logging=False, + model_data_download_timeout=3600, + container_startup_health_check_timeout=600, + ) + + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -451,11 +509,11 @@ def test_deprecated( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -468,7 +526,7 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -478,9 +536,9 @@ def test_vulnerable( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_model_deploy.return_value = default_predictor @@ -543,7 +601,7 @@ def test_model_use_kwargs(self): ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -555,7 +613,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, @@ -565,7 +623,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_session.return_value = sagemaker_session @@ -661,22 +719,22 @@ def test_jumpstart_model_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -688,14 +746,14 @@ def test_no_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -717,12 +775,13 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -734,14 +793,14 @@ def test_no_predictor_yes_async_inference_config( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -758,7 +817,7 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -770,14 +829,14 @@ def test_yes_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -793,24 +852,24 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.JumpStartModelsAccessor.reset_cache") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) - def test_model_id_not_found_refeshes_cach_inference( + def test_model_id_not_found_refeshes_cache_inference( self, mock_reset_cache: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -826,7 +885,7 @@ def test_model_id_not_found_refeshes_cach_inference( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -845,16 +904,19 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [ + False, + JumpStartModelType.OPEN_WEIGHTS, + ] JumpStartModel( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -873,16 +935,16 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -909,16 +971,16 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -941,16 +1003,16 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -975,16 +1037,16 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS # arbitrary model without model packarn arn model_id, _ = "js-trainable-model", "*" @@ -1017,7 +1079,7 @@ def test_jumpstart_model_package_arn_override( }, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1025,10 +1087,10 @@ def test_jumpstart_model_package_arn_unsupported_region( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -1044,7 +1106,7 @@ def test_jumpstart_model_package_arn_unsupported_region( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1058,14 +1120,14 @@ def test_model_data_s3_prefix_override( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1109,7 +1171,7 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1123,11 +1185,11 @@ def test_model_data_s3_prefix_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1153,7 +1215,7 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1167,11 +1229,11 @@ def test_model_artifact_variant_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1218,7 +1280,7 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1230,11 +1292,11 @@ def test_model_registry_accept_and_response_types( mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 727f3060b3..70409704e6 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.model import JumpStartModel +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -59,7 +60,7 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -75,10 +76,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_retrieve_kwargs.return_value = {} @@ -100,7 +101,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -116,10 +117,10 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -146,7 +147,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -162,10 +163,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -192,7 +193,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -208,10 +209,10 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -240,7 +241,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -256,10 +257,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -286,7 +287,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -302,9 +303,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -333,7 +334,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -349,10 +350,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -374,7 +375,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -390,10 +391,10 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 97427be1ae..c57d2a958b 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -18,6 +18,7 @@ import pytest from sagemaker.jumpstart import accessors +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST from tests.unit.sagemaker.jumpstart.utils import ( get_header_from_base_header, @@ -63,8 +64,51 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): reload(accessors) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_proprietary_models_cache_get(mock_cache): + + mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) + mock_cache.get_header = Mock(side_effect=get_header_from_base_header) + mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) + + assert get_header_from_base_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert get_spec_from_base_spec( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + assert ( + len( + accessors.JumpStartModelsAccessor._get_manifest( + model_type=JumpStartModelType.PROPRIETARY + ) + ) + > 0 + ) + + # necessary because accessors is a static module + reload(accessors) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") -def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): +def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock): # test change of region resets cache accessors.JumpStartModelsAccessor.get_model_header( @@ -138,6 +182,50 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): reload(accessors) +@patch("sagemaker.jumpstart.cache.JumpStartModelsCache") +def test_jumpstart_proprietary_models_cache_set_reset(mock_model_cache: Mock): + + # test change of region resets cache + accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_header( + region="us-east-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-1", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + # necessary because accessors is a static module + reload(accessors) + + class TestS3Accessor(TestCase): bucket = "bucket" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 1a770f785f..21112926a5 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -32,7 +32,7 @@ from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn from sagemaker.jumpstart.artifacts.model_uris import _retrieve_model_uri -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec from tests.unit.sagemaker.workflow.conftest import mock_client @@ -331,9 +331,13 @@ class RetrieveModelPackageArnTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_retrieve_model_package_arn(self, patched_get_model_specs): + def test_retrieve_model_package_arn( + self, patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "variant-model" region = "us-west-2" @@ -437,9 +441,13 @@ class PrivateJumpStartBucketTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): + def test_retrieve_uri_from_gated_bucket( + self, patched_get_model_specs, patched_validate_model_id_and_get_type + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "private-model" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 6633ecdc23..50fe6da0a6 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -23,7 +23,11 @@ import pytest from mock import patch -from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.cache import ( + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, + JumpStartModelsCache, +) from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -32,7 +36,9 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + JumpStartS3FileType, ) +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import ( get_spec_from_base_spec, patched_retrieval_function, @@ -41,6 +47,8 @@ from tests.unit.sagemaker.jumpstart.constants import ( BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_SPEC, + BASE_PROPRIETARY_MANIFEST, ) from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @@ -152,6 +160,33 @@ def test_jumpstart_cache_get_header(): semantic_version_str="1.0.*", ) + assert JumpStartModelHeader( + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + } + ) == cache.get_header( + model_id="ai21-summarization", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="3.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "You can pin to version '1.1.003'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for a list of valid model IDs. " in str(e.value) + ) + with pytest.raises(KeyError) as e: cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -194,6 +229,19 @@ def test_jumpstart_cache_get_header(): "v3-classification-4'?" ) in str(e.value) + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarize", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for updated list of models. " + "Did you mean to use model ID 'ai21-summarization'?" + ) in str(e.value) + with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -224,6 +272,27 @@ def test_jumpstart_cache_get_header(): semantic_version_str="*", ) + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.1.004", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="2.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="v*", + model_type=JumpStartModelType.PROPRIETARY, + ) + @patch("boto3.client") def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): @@ -276,6 +345,12 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.set_manifest_file_s3_key("some_key1") cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + cache.clear.assert_called_once() + with pytest.raises(ValueError): + cache.set_manifest_file_s3_key("some_key1", file_type="unknown_type") + def test_jumpstart_cache_handles_boto3_client_errors(): # Testing get_object @@ -423,15 +498,80 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache._s3_cache._max_cache_items == max_s3_cache_items assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon assert ( - cache._model_id_semantic_version_manifest_key_cache._max_cache_items + cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items ) assert ( - cache._model_id_semantic_version_manifest_key_cache._expiration_horizon + cache._open_weight_model_id_manifest_key_cache._expiration_horizon == semantic_version_cache_expiration_horizon ) +def test_jumpstart_proprietary_cache_accepts_input_parameters(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + + assert ( + cache.get_manifest_file_s3_key(file_type=JumpStartS3FileType.PROPRIETARY_MANIFEST) + == proprietary_manifest_file_key + ) + assert cache.get_region() == region + assert cache.get_bucket() == bucket + assert cache._s3_cache._max_cache_items == max_s3_cache_items + assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert ( + cache._proprietary_model_id_manifest_key_cache._max_cache_items + == max_semantic_version_cache_items + ) + assert ( + cache._proprietary_model_id_manifest_key_cache._expiration_horizon + == semantic_version_cache_expiration_horizon + ) + + +def test_jumpstart_cache_raise_unknown_file_type_exception(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + with pytest.raises(ValueError): + cache.get_manifest_file_s3_key(file_type="unknown_type") + + @patch("boto3.client") def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): @@ -583,7 +723,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( with patch("logging.Logger.warning") as mocked_warning_log: cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_called_once_with( "Using model 'pytorch-ic-imagenet-inception-v3-classification-4' with wildcard " @@ -593,7 +733,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( ) mocked_warning_log.reset_mock() cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_not_called() @@ -605,13 +745,85 @@ def test_jumpstart_cache_makes_correct_s3_calls( mock_boto3_client.return_value.head_object.assert_not_called() +@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs") +@patch("boto3.client") +def test_jumpstart_cache_proprietary_manifest_makes_correct_s3_calls( + mock_boto3_client, mock_emit_logs_based_on_model_specs +): + + # test get_header + mock_manifest_json = json.dumps( + [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + ] + ) + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_manifest_json, "utf-8")), content_length=len(mock_manifest_json) + ), + "ETag": "etag", + } + + mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} + + bucket_name = get_jumpstart_content_bucket("us-west-2") + client_config = botocore.config.Config(signature_version="my_signature_version") + cache = JumpStartModelsCache( + s3_bucket_name=bucket_name, s3_client_config=client_config, region="us-west-2" + ) + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + mock_boto3_client.assert_called_with("s3", region_name="us-west-2", config=client_config) + + # test get_specs. manifest already in cache, so only s3 call will be to get specs. + mock_json = json.dumps(BASE_PROPRIETARY_SPEC) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "etag", + } + + with patch("logging.Logger.warning") as mocked_warning_log: + cache.get_specs( + model_id="ai21-summarization", + version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + mocked_warning_log.assert_not_called() + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, + Key="proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") cache.clear = MagicMock() - cache._model_id_semantic_version_manifest_key_cache = MagicMock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache = MagicMock() + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -640,7 +852,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() cache.clear.reset_mock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -668,7 +880,18 @@ def test_jumpstart_get_full_manifest(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") raw_manifest = [header.to_json() for header in cache.get_manifest()] - raw_manifest == BASE_MANIFEST + assert raw_manifest == BASE_MANIFEST + + +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_get_full_proprietary_manifest(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + raw_manifest = [ + header.to_json() for header in cache.get_manifest(model_type=JumpStartModelType.PROPRIETARY) + ] + + assert raw_manifest == BASE_PROPRIETARY_MANIFEST @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @@ -678,54 +901,92 @@ def test_jumpstart_cache_get_specs(): model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="2.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="2.0.*" + model_id=model_id, version_str="2.0.*" ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="1.*" + model_id=model_id, version_str="1.*" ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="1.0.*" + model_id=model_id, version_str="1.0.*" + ) + + assert get_spec_from_base_spec( + model_id="ai21-summarization", + version="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) == cache.get_specs( + model_id="ai21-summarization", + version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError) as e: + cache.get_specs( + model_id="ai21-summarization", + version_str="3.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "You can pin to version '1.1.003'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for a list of valid model IDs. " in str(e.value) ) with pytest.raises(KeyError): - cache.get_specs(model_id=model_id + "bak", semantic_version_str="*") + cache.get_specs(model_id=model_id + "bak", version_str="*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="9.*") + cache.get_specs(model_id=model_id, version_str="9.*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="BAD") + cache.get_specs(model_id=model_id, version_str="BAD") with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="2.1.*", + version_str="2.1.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="3.9.*", + version_str="3.9.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="5.*", + version_str="5.*", + ) + + model_id, version = "ai21-summarization", "2.0.0" + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="BAD", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="9.*", + model_type=JumpStartModelType.PROPRIETARY, ) @@ -794,9 +1055,7 @@ def test_jumpstart_local_metadata_override_specs( cache = JumpStartModelsCache(s3_bucket_name="some_bucket") model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" - assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( - model_id=model_id, semantic_version_str=version - ) + assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs(model_id=model_id, version_str=version) mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") @@ -840,7 +1099,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( cache = JumpStartModelsCache(s3_bucket_name="some_bucket") assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") diff --git a/tests/unit/sagemaker/jumpstart/test_exceptions.py b/tests/unit/sagemaker/jumpstart/test_exceptions.py index 555099a753..2307d22474 100644 --- a/tests/unit/sagemaker/jumpstart/test_exceptions.py +++ b/tests/unit/sagemaker/jumpstart/test_exceptions.py @@ -11,10 +11,15 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + +from botocore.exceptions import ClientError from sagemaker.jumpstart.exceptions import ( get_wildcard_model_version_msg, get_old_model_version_msg, + get_proprietary_model_subscription_error, + MarketplaceModelSubscriptionError, ) @@ -35,3 +40,32 @@ def test_get_old_model_version_msg(): "Note that models may have different input/output signatures after a major " "version upgrade." == get_old_model_version_msg("mother_of_all_models", "1.0.0", "1.2.3") ) + + +def test_get_marketplace_subscription_error(): + error = ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Caller is not subscribed to the Marketplace listing.", + }, + }, + operation_name="mock-operation", + ) + with pytest.raises(MarketplaceModelSubscriptionError): + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + + error = ClientError( + error_response={ + "Error": { + "Code": "UnknownException", + "Message": "Unknown error raised.", + }, + }, + operation_name="mock-operation", + ) + + try: + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + except MarketplaceModelSubscriptionError: + pytest.fail("MarketplaceModelSubscriptionError should not be raised for unknown error.") diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 1a7108579c..059cd7ccad 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -17,6 +17,7 @@ get_prototype_manifest, get_prototype_model_spec, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, get_model_url, @@ -63,7 +64,7 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 assert patched_get_model_specs.call_count == 1 patched_get_model_specs.reset_mock() @@ -78,8 +79,8 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() - assert patched_read_s3_file.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) + assert patched_get_manifest.call_count == 2 + assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -107,7 +108,7 @@ def test_list_jumpstart_tasks( ) # incomplete list, based on mocked metadata patched_generate_jumpstart_models.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() @@ -122,7 +123,7 @@ def test_list_jumpstart_tasks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() @@ -154,11 +155,11 @@ def test_list_jumpstart_frameworks( ) patched_generate_jumpstart_models.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() - patched_get_manifest.reset_mock() + assert patched_get_manifest.call_count == 2 patched_generate_jumpstart_models.reset_mock() kwargs = { @@ -180,7 +181,7 @@ def test_list_jumpstart_frameworks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 4 patched_get_model_specs.assert_not_called() @@ -238,7 +239,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": f"training_supported != {val}"} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -255,7 +256,7 @@ def test_list_jumpstart_models_script_filter( ("xgboost-classification-model", "1.0.0"), ] assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -264,7 +265,7 @@ def test_list_jumpstart_models_script_filter( models = list_jumpstart_models(**kwargs) assert [] == models assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -287,7 +288,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task == {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -295,7 +296,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task != {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -312,7 +313,7 @@ def test_list_jumpstart_models_task_filter( ("xgboost-classification-model", "1.0.0"), ] patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -321,7 +322,7 @@ def test_list_jumpstart_models_task_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") @@ -348,7 +349,7 @@ def test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework == {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -356,7 +357,7 @@ def test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework != {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -372,7 +373,7 @@ def test_list_jumpstart_models_framework_filter( ("xgboost-classification-model", "1.0.0"), ] patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -390,8 +391,8 @@ def test_list_jumpstart_models_framework_filter( } models = list_jumpstart_models(**kwargs) assert [("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0")] == models - patched_read_s3_file.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -405,7 +406,7 @@ def test_list_jumpstart_models_framework_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -420,8 +421,10 @@ def test_list_jumpstart_models_region( list_jumpstart_models(region="some-region") - patched_get_manifest.assert_called_once_with( - region="some-region", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client + patched_get_manifest.assert_called_with( + region="some-region", + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -478,7 +481,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): ] == list_jumpstart_models(list_old_models=True, list_versions=True) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -526,8 +529,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -538,8 +541,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -570,8 +573,8 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) assert [] == list_jumpstart_models("deprecated equals false") - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -607,6 +610,43 @@ def test_list_jumpstart_models_no_versions( assert list_jumpstart_models(list_versions=False) == all_model_ids + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + def test_list_jumpstart_proprietary_models( + self, + patched_get_model_specs: Mock, + patched_get_manifest: Mock, + ): + patched_get_model_specs.side_effect = get_prototype_model_spec + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + all_prop_model_ids = [ + "ai21-paraphrase", + "ai21-summarization", + "lighton-mini-instruct40b", + ] + + all_open_weight_model_ids = [ + "catboost-classification-model", + "huggingface-spc-bert-base-cased", + "lightgbm-classification-model", + "mxnet-semseg-fcn-resnet50-ade", + "pytorch-eqa-bert-base-cased", + "sklearn-classification-linear", + "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", + "xgboost-classification-model", + ] + + assert list_jumpstart_models("model_type == proprietary") == all_prop_model_ids + assert list_jumpstart_models("model_type == marketplace") == all_prop_model_ids + assert list_jumpstart_models("model_type == open_weights") == all_open_weight_model_ids + + assert list_jumpstart_models(list_versions=False) == sorted( + all_prop_model_ids + all_open_weight_model_ids + ) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_complex_queries( @@ -670,12 +710,20 @@ def test_list_jumpstart_models_multiple_level_index( list_jumpstart_models("hosting_ecr_specs.py_version == py3") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_get_model_url( patched_get_model_specs: Mock, + patched_validate_model_id_and_get_type: Mock, + patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -686,7 +734,6 @@ def test_get_model_url( ) model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" - region = "fake-region" patched_get_model_specs.reset_mock() patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_prototype_model_spec( @@ -695,11 +742,12 @@ def test_get_model_url( **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region=region) + get_model_url(model_id, version, region="us-west-2") patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, - region=region, + region="us-west-2", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 7ab9cdd1cc..52f28f2da1 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -7,17 +7,15 @@ import pytest from sagemaker.deserializers import JSONDeserializer from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import MIMEType, JumpStartModelType from sagemaker import predictor from sagemaker.jumpstart.model import JumpStartModel from sagemaker.jumpstart.utils import verify_model_region_and_return_specs -from sagemaker.serializers import IdentitySerializer -from tests.unit.sagemaker.jumpstart.utils import ( - get_special_model_spec, -) +from sagemaker.serializers import IdentitySerializer, JSONSerializer +from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -54,6 +52,43 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON +@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_proprietary_predictor_support( + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_get_jumpstart_model_id_version_from_endpoint, +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + # version not needed for JumpStart predictor + model_id, model_version = "ai21-summarization", "*" + + patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + model_id, + model_version, + None, + ) + + js_predictor = predictor.retrieve_default( + endpoint_name="blah", + model_id=model_id, + model_version=model_version, + model_type=JumpStartModelType.PROPRIETARY, + ) + + patched_get_jumpstart_model_id_version_from_endpoint.assert_not_called() + + assert js_predictor.content_type == MIMEType.JSON + assert isinstance(js_predictor.serializer, JSONSerializer) + + assert isinstance(js_predictor.deserializer, JSONDeserializer) + assert js_predictor.accept == MIMEType.JSON + + @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -92,6 +127,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=mock_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @@ -125,19 +161,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") -@patch("sagemaker.jumpstart.model.is_valid_model_id") +@patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_is_valid_model_id, + patched_validate_model_id_and_get_type, patched_get_object_cached, patched_get_model_id_version_from_endpoint, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") - patched_is_valid_model_id.return_value = True + patched_validate_model_id_and_get_type.return_value = True patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 556b99bc9c..3ec6f8aec3 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -32,7 +32,7 @@ JumpStartScriptScope, ) from functools import partial -from sagemaker.jumpstart.enums import JumpStartTag, MIMEType +from sagemaker.jumpstart.enums import JumpStartTag, MIMEType, JumpStartModelType from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, VulnerableJumpStartModelError, @@ -68,7 +68,7 @@ def test_get_jumpstart_content_bucket_override(): with patch("logging.Logger.info") as mocked_info_log: random_region = "random_region" assert "some-val" == utils.get_jumpstart_content_bucket(random_region) - mocked_info_log.assert_called_once_with("Using JumpStart bucket override: 'some-val'") + mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") def test_get_jumpstart_gated_content_bucket(): @@ -1180,7 +1180,7 @@ def test_mime_type_enum_from_str(): class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_true( + def test_validate_model_id_and_get_type_true( self, mock_get_model_specs: Mock, mock_get_manifest: Mock, @@ -1194,12 +1194,16 @@ def test_is_valid_model_id_true( mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): - self.assertTrue(utils.is_valid_model_id("bee")) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): + self.assertTrue(utils.validate_model_id_and_get_type("bee")) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_model_specs.assert_not_called() @@ -1213,14 +1217,20 @@ def test_is_valid_model_id_true( ] mock_get_model_specs.return_value = Mock(training_supported=True) - self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("bee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_manifest: Mock): + def test_validate_model_id_and_get_type_false( + self, mock_get_model_specs: Mock, mock_get_manifest: Mock + ): mock_get_manifest.return_value = [ Mock(model_id="ay"), Mock(model_id="bee"), @@ -1230,18 +1240,18 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): - self.assertFalse(utils.is_valid_model_id("dee")) - self.assertFalse(utils.is_valid_model_id("")) - self.assertFalse(utils.is_valid_model_id(None)) - self.assertFalse(utils.is_valid_model_id(set())) + self.assertFalse(utils.validate_model_id_and_get_type("dee")) + self.assertFalse(utils.validate_model_id_and_get_type("")) + self.assertFalse(utils.validate_model_id_and_get_type(None)) + self.assertFalse(utils.validate_model_id_and_get_type(set())) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value - ) + mock_get_manifest.assert_called() mock_get_model_specs.assert_not_called() @@ -1253,30 +1263,48 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type("", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(None, script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(set(), script=JumpStartScriptScope.TRAINING) + ) mock_get_model_specs.assert_not_called() - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() mock_get_model_specs.reset_mock() mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertTrue(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("ay", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 146c6fd1f7..65fe10f7a7 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -28,12 +28,16 @@ JumpStartS3FileType, JumpStartModelHeader, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_MANIFEST, + BASE_PROPRIETARY_SPEC, BASE_HEADER, + BASE_PROPRIETARY_HEADER, SPECIAL_MODEL_SPECS_DICT, ) @@ -44,11 +48,16 @@ def get_header_from_base_header( model_id: str = None, semantic_version_str: str = None, version: str = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: if version and semantic_version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_HEADER) + return JumpStartModelHeader(spec) + if all( [ "pytorch" not in model_id, @@ -79,7 +88,10 @@ def get_header_from_base_header( def get_prototype_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: + if model_type == JumpStartModelType.PROPRIETARY: + return [JumpStartModelHeader(spec) for spec in BASE_PROPRIETARY_MANIFEST] return [ get_header_from_base_header(region=region, model_id=model_id, version=version) for model_id in PROTOTYPICAL_MODEL_SPECS_DICT.keys() @@ -92,6 +104,7 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -107,6 +120,7 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -122,6 +136,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID and adding @@ -142,14 +157,22 @@ def get_spec_from_base_spec( _obj: JumpStartModelsCache = None, region: str = None, model_id: str = None, - semantic_version_str: str = None, + version_str: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: - if version and semantic_version_str: + if version and version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_SPEC) + spec["version"] = version or version_str + spec["model_id"] = model_id + + return JumpStartModelSpecs(spec) + if all( [ "pytorch" not in model_id, @@ -172,7 +195,7 @@ def get_spec_from_base_spec( spec = copy.deepcopy(BASE_SPEC) - spec["version"] = version or semantic_version_str + spec["version"] = version or version_str spec["model_id"] = model_id return JumpStartModelSpecs(spec) @@ -185,19 +208,35 @@ def patched_retrieval_function( ) -> JumpStartCachedS3ContentValue: filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.MANIFEST: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: return JumpStartCachedS3ContentValue( formatted_content=get_formatted_manifest(BASE_MANIFEST) ) - if filetype == JumpStartS3FileType.SPECS: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: _, model_id, specs_version = s3_key.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") return JumpStartCachedS3ContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) + if filetype == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedS3ContentValue( + formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) + ) + + if filetype == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = s3_key.split("/") + version = specs_version.replace("proprietary_specs_", "").replace(".json", "") + return JumpStartCachedS3ContentValue( + formatted_content=get_spec_from_base_spec( + model_id=model_id, + version=version, + model_type=JumpStartModelType.PROPRIETARY, + ) + ) + raise ValueError(f"Bad value for filetype: {filetype}") diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ffc6000c91..608a32a005 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import metric_definitions +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -25,10 +26,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_metric_definitions(patched_get_model_specs): +def test_jumpstart_default_metric_definitions( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,7 +52,11 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -63,7 +72,11 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 000540e12e..8d75731b06 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import model_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.model_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_model_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,6 +52,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +70,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -82,6 +89,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -100,6 +108,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index b0cef0e3d4..7b5e7a598d 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -18,6 +18,7 @@ import pytest from sagemaker import resource_requirements +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.jumpstart.artifacts.resource_requirements import ( REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP, @@ -26,10 +27,14 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_resource_requirements(patched_get_model_specs): +def test_jumpstart_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS region = "us-west-2" mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -50,6 +55,7 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -103,9 +109,14 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode } +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): +def test_jumpstart_no_supported_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): + patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" @@ -126,6 +137,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 3f38326608..c797ba3559 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import script_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.script_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_script_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,6 +52,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +70,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,7 +85,11 @@ def test_jumpstart_common_script_uri( sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -97,6 +108,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index b22b61dc40..c2253726bf 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -19,18 +19,23 @@ from sagemaker import base_serializers, serializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_serializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -50,19 +55,24 @@ def test_jumpstart_default_serializers( model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -89,4 +99,5 @@ def test_jumpstart_serializer_options( model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) From 064378d76a239317cbc06d9e680a0580b85e11e4 Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Tue, 12 Mar 2024 20:12:10 -0400 Subject: [PATCH 07/20] chore: emit warning when no instance specific gated training env var is available, and raise exception when accept_eula flag is not supplied (#4485) * fix: raise exception when no instance specific gated training env var available * chore: raise client exception if accept_eula flag is not set for gated models * chore: address flake8 errors * chore: emit warning when instance type is chosen with no gated training artifacts --- .../artifacts/environment_variables.py | 34 +- src/sagemaker/jumpstart/factory/estimator.py | 21 + src/sagemaker/jumpstart/types.py | 4 + src/sagemaker/jumpstart/utils.py | 22 +- .../jumpstart/test_default.py | 65 + tests/unit/sagemaker/jumpstart/constants.py | 1236 +++++++++++++++++ .../jumpstart/estimator/test_estimator.py | 46 +- tests/unit/sagemaker/jumpstart/test_utils.py | 10 +- 8 files changed, 1388 insertions(+), 50 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 0e666e4c14..006559852c 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -12,10 +12,11 @@ # language governing permissions and limitations under the License. """This module contains functions for obtaining JumpStart environment variables.""" from __future__ import absolute_import -from typing import Dict, Optional +from typing import Callable, Dict, Optional, Set from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, + JUMPSTART_LOGGER, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) from sagemaker.jumpstart.enums import ( @@ -110,7 +111,9 @@ def _retrieve_default_environment_variables( default_environment_variables.update(instance_specific_environment_variables) - gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value( + retrieve_gated_env_var_for_instance_type: Callable[ + [str], Optional[str] + ] = lambda instance_type: _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, region=region, @@ -120,6 +123,33 @@ def _retrieve_default_environment_variables( instance_type=instance_type, ) + gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type( + instance_type + ) + + if gated_model_env_var is None and model_specs.is_gated_model(): + + possible_env_vars: Set[str] = { + retrieve_gated_env_var_for_instance_type(instance_type) + for instance_type in model_specs.supported_training_instance_types + } + + # If all officially supported instance types have the same underlying artifact, + # we can use this artifact with high confidence that it'll succeed with + # an arbitrary instance. + if len(possible_env_vars) == 1: + gated_model_env_var = list(possible_env_vars)[0] + + # If this model does not have 1 artifact for all supported instance types, + # we cannot determine which artifact to use for an arbitrary instance. + else: + log_msg = ( + f"'{model_id}' does not support {instance_type} instance type" + " for training. Please use one of the following instance types: " + f"{', '.join(model_specs.supported_training_instance_types)}." + ) + JUMPSTART_LOGGER.warning(log_msg) + if gated_model_env_var is not None: default_environment_variables.update( {SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var} diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7c20c281f5..86630fcfb8 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -62,6 +62,7 @@ ) from sagemaker.jumpstart.utils import ( add_jumpstart_model_id_version_tags, + get_eula_message, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, verify_model_region_and_return_specs, @@ -597,6 +598,26 @@ def _add_env_to_kwargs( value, ) + environment = getattr(kwargs, "environment", {}) or {} + if ( + environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY) + and str(environment.get("accept_eula", "")).lower() != "true" + ): + model_specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + region=kwargs.region, + scope=JumpStartScriptScope.TRAINING, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, + ) + if model_specs.is_gated_model(): + raise ValueError( + "Need to define ‘accept_eula'='true' within Environment. " + f"{get_eula_message(model_specs, kwargs.region)}" + ) + return kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 3142596eba..6a389f385f 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -963,6 +963,10 @@ def use_training_model_artifact(self) -> bool: # otherwise, return true is a training model package is not set return len(self.training_model_package_artifact_uris or {}) == 0 + def is_gated_model(self) -> bool: + """Returns True if the model has a EULA key or the model bucket is gated.""" + return self.gated_bucket or self.hosting_eula_key is not None + def supports_incremental_training(self) -> bool: """Returns True if the model supports incremental training.""" return self.incremental_training_supported diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 71a8067a6f..62ccba7900 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -476,21 +476,25 @@ def update_inference_tags_with_jumpstart_training_tags( return inference_tags +def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: + """Returns EULA message to display if one is available, else empty string.""" + if model_specs.hosting_eula_key is None: + return "" + return ( + f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). " + f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." + f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}" + f"/{model_specs.hosting_eula_key} for terms of use." + ) + + def emit_logs_based_on_model_specs( model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client ) -> None: """Emits logs based on model specs and region.""" if model_specs.hosting_eula_key: - constants.JUMPSTART_LOGGER.info( - "Model '%s' requires accepting end-user license agreement (EULA). " - "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", - model_specs.model_id, - get_jumpstart_content_bucket(region=region), - region, - ".cn" if region.startswith("cn-") else "", - model_specs.hosting_eula_key, - ) + constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region)) full_version: str = model_specs.version diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index 38cc5ebbf3..cc1aad8a44 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import environment_variables +from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -203,6 +204,70 @@ def test_jumpstart_sdk_environment_variables( ) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_sdk_environment_variables_1_artifact_all_variants(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "gemma-model-1-artifact" + region = "us-west-2" + + assert { + "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/" + "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz" + } == environment_variables.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, + instance_type="ml.p3.2xlarge", + script="training", + ) + + +@patch("sagemaker.jumpstart.artifacts.environment_variables.JUMPSTART_LOGGER") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_sdk_environment_variables_no_gated_env_var_available( + patched_get_model_specs, patched_jumpstart_logger +): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "gemma-model" + region = "us-west-2" + + assert {} == environment_variables.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, + instance_type="ml.p3.2xlarge", + script="training", + ) + + patched_jumpstart_logger.warning.assert_called_once_with( + "'gemma-model' does not support ml.p3.2xlarge instance type for " + "training. Please use one of the following instance types: " + "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge." + ) + + # assert that supported instance types succeed + assert { + "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/" + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz" + } == environment_variables.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, + instance_type="ml.g5.24xlarge", + script="training", + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs): diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index ce8cc4ddfa..d7c4eb4921 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -14,6 +14,1242 @@ SPECIAL_MODEL_SPECS_DICT = { + "gemma-model": { + "model_id": "huggingface-llm-gemma-7b-instruct", + "url": "https://huggingface.co/google/gemma-7b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", + }, + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-instruct/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-i" + "nstruct/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "peft_type", + "type": "text", + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, + { + "name": "double_quant", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "quant_type", + "type": "text", + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_from_scratch", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "fp16", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "bf16", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "evaluation_strategy", + "type": "text", + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", + }, + { + "name": "eval_steps", + "type": "int", + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 2048, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "gradient_checkpointing", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-7b-instruct", + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/h" + "uggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:" + "2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingf" + "ace-llm-gemma-7b-instruct.tar.gz" + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/" + "p4d/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 98304, "num_accelerators": 4}, + "dynamic_container_deployment_supported": True, + }, + }, + "gemma-model-1-artifact": { + "model_id": "huggingface-llm-gemma-7b-instruct", + "url": "https://huggingface.co/google/gemma-7b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", + }, + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-instruct/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-i" + "nstruct/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "peft_type", + "type": "text", + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, + { + "name": "double_quant", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "quant_type", + "type": "text", + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_from_scratch", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "fp16", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "bf16", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "evaluation_strategy", + "type": "text", + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", + }, + { + "name": "eval_steps", + "type": "int", + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 2048, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "gradient_checkpointing", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-7b-instruct", + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/h" + "uggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:" + "2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 98304, "num_accelerators": 4}, + "dynamic_container_deployment_supported": True, + }, + }, "env-var-variant-model": { "model_id": "huggingface-llm-falcon-180b-bf16", "url": "https://huggingface.co/tiiuae/falcon-180B", diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4fa18f31aa..fe4b122c4a 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -319,39 +319,21 @@ def test_gated_model_s3_uri( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - JumpStartEstimator( - model_id=model_id, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - ) - - mock_estimator_init.assert_called_once_with( - instance_type="ml.p3.2xlarge", - instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117", - source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" - "meta/transfer_learning/textgeneration/v1.0.0/sourcedir.tar.gz", - entry_point="transfer_learning.py", - role=execution_role, - sagemaker_session=sagemaker_session, - max_run=360000, - enable_network_isolation=True, - encrypt_inter_container_traffic=True, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - tags=[ - { - "Key": "sagemaker-sdk:jumpstart-model-id", - "Value": "js-gated-artifact-trainable-model", + with pytest.raises(ValueError) as e: + JumpStartEstimator( + model_id=model_id, + environment={ + "accept_eula": "false", + "what am i": "doing", + "SageMakerGatedModelS3Uri": "none of your business", }, - {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.0"}, - ], + ) + assert str(e.value) == ( + "Need to define ‘accept_eula'='true' within Environment. " + "Model 'meta-textgeneration-llama-2-7b-f' requires accepting end-user " + "license agreement (EULA). See " + "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llamaEula.txt" + " for terms of use." ) mock_estimator_init.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 3ec6f8aec3..cb54722d48 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -905,13 +905,9 @@ def make_accept_eula_inference_spec(*largs, **kwargs): make_accept_eula_inference_spec(), "us-east-1", MOCK_CLIENT ) mocked_info_log.assert_any_call( - "Model '%s' requires accepting end-user license agreement (EULA). " - "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", - "pytorch-eqa-bert-base-cased", - "jumpstart-cache-prod-us-east-1", - "us-east-1", - "", - "read/the/fine/print.txt", + "Model 'pytorch-eqa-bert-base-cased' requires accepting end-user license agreement (EULA). " + "See https://jumpstart-cache-prod-us-east-1.s3.us-east-1.amazonaws.com/read/the/fine/print.txt" + " for terms of use.", ) From 377be874be5f2cf378e3abd33a4dc2e80120358e Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:13:48 -0400 Subject: [PATCH 08/20] fix: sagemaker session region not being used (#4469) * fix: sagemaker session region not being used * chore: add unit tests * fix: remove all JUMPSTART_DEFAULT_REGION_NAME default arguments * chore: use get_region_fallback throughout * chore: remove unnecessary if statement * chore: remove unnecessary if statement (2) --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> --- .../artifacts/environment_variables.py | 12 +- .../jumpstart/artifacts/hyperparameters.py | 7 +- .../jumpstart/artifacts/image_uris.py | 7 +- .../artifacts/incremental_training.py | 7 +- .../jumpstart/artifacts/instance_types.py | 12 +- src/sagemaker/jumpstart/artifacts/kwargs.py | 22 ++-- .../jumpstart/artifacts/metric_definitions.py | 7 +- .../jumpstart/artifacts/model_packages.py | 12 +- .../jumpstart/artifacts/model_uris.py | 12 +- src/sagemaker/jumpstart/artifacts/payloads.py | 7 +- .../jumpstart/artifacts/predictors.py | 22 ++-- .../jumpstart/artifacts/resource_names.py | 7 +- .../artifacts/resource_requirements.py | 7 +- .../jumpstart/artifacts/script_uris.py | 12 +- src/sagemaker/jumpstart/cache.py | 105 ++++++++---------- src/sagemaker/jumpstart/estimator.py | 2 +- src/sagemaker/jumpstart/factory/estimator.py | 8 +- src/sagemaker/jumpstart/factory/model.py | 6 +- src/sagemaker/jumpstart/model.py | 2 +- src/sagemaker/jumpstart/notebook_utils.py | 45 +++++--- src/sagemaker/jumpstart/payload_utils.py | 8 +- src/sagemaker/jumpstart/utils.py | 45 +++++++- src/sagemaker/jumpstart/validators.py | 16 +-- .../jumpstart/test_accept_types.py | 3 +- .../jumpstart/test_content_types.py | 3 +- .../jumpstart/test_deserializers.py | 3 +- .../jumpstart/test_default.py | 3 +- .../hyperparameters/jumpstart/test_default.py | 3 +- .../jumpstart/test_validate.py | 3 +- .../image_uris/jumpstart/test_common.py | 3 +- .../jumpstart/test_instance_types.py | 5 +- .../jumpstart/estimator/test_estimator.py | 61 +++++++++- .../sagemaker/jumpstart/model/test_model.py | 65 ++++++++++- .../sagemaker/jumpstart/test_artifacts.py | 6 +- tests/unit/sagemaker/jumpstart/test_utils.py | 49 ++++++++ .../jumpstart/test_default.py | 6 +- .../model_uris/jumpstart/test_common.py | 3 +- .../jumpstart/test_resource_requirements.py | 4 +- .../script_uris/jumpstart/test_common.py | 3 +- .../serializers/jumpstart/test_serializers.py | 5 +- 40 files changed, 437 insertions(+), 181 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 006559852c..fa5ae8900b 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -15,7 +15,6 @@ from typing import Callable, Dict, Optional, Set from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) @@ -24,6 +23,7 @@ ) from sagemaker.jumpstart.utils import ( get_jumpstart_gated_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -72,8 +72,9 @@ def _retrieve_default_environment_variables( dict: the inference environment variables to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -198,8 +199,9 @@ def _retrieve_gated_model_uri_env_var_value( ValueError: If the model specs specified are invalid. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index e9e6f613f8..d19530ecfb 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -15,13 +15,13 @@ from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, VariableScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -70,8 +70,9 @@ def _retrieve_default_hyperparameters( dict: the hyperparameters to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 6ea1ca84a1..9d19d5e069 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -17,13 +17,13 @@ from sagemaker import image_uris from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ModelFramework, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -104,8 +104,9 @@ def _retrieve_image_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 753a911422..1b3c6f4b29 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -15,12 +15,12 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -58,8 +58,9 @@ def _model_supports_incremental_training( bool: the support status for incremental training. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 608303c5e6..e7c9c5911d 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -18,13 +18,13 @@ from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -76,8 +76,9 @@ def _retrieve_default_instance_type( specified region due to lack of supported computing instances. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -163,8 +164,9 @@ def _retrieve_instance_types( specified region due to lack of supported computing instances. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index c15f686805..9cd152b0bb 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -18,13 +18,13 @@ from sagemaker.utils import volume_size_supported from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) @@ -62,8 +62,9 @@ def _retrieve_model_init_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -121,8 +122,9 @@ def _retrieve_model_deploy_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -176,8 +178,9 @@ def _retrieve_estimator_init_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -233,8 +236,9 @@ def _retrieve_estimator_fit_kwargs( dict: the kwargs to use for the use case. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index b6f6019641..57f66155c7 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -16,12 +16,12 @@ from typing import Dict, List, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -62,8 +62,9 @@ def _retrieve_default_training_metric_definitions( list: the default training metric definitions to use for the model or None. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 5c8a2488c6..aa22351771 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -15,9 +15,9 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.jumpstart.enums import ( @@ -65,8 +65,9 @@ def _retrieve_model_package_arn( str: the model package arn to use for the model or None. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -149,8 +150,9 @@ def _retrieve_model_package_model_artifact_s3_uri( if scope == JumpStartScriptScope.TRAINING: - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index c41f0a75b7..6bb2e576fc 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -18,7 +18,6 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -26,6 +25,7 @@ from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -129,8 +129,9 @@ def _retrieve_model_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -206,8 +207,9 @@ def _model_supports_training_model_uri( bool: the support status for model uri with training. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 0424145119..3359e32732 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -16,7 +16,6 @@ from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, @@ -24,6 +23,7 @@ ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -63,8 +63,9 @@ def _retrieve_example_payloads( to the serializable payload object. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index e9e0e8dfde..4f6dfe1fe3 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -20,7 +20,6 @@ CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP, DEFAULT_JUMPSTART_SAGEMAKER_SESSION, DESERIALIZER_TYPE_TO_CLASS_MAP, - JUMPSTART_DEFAULT_REGION_NAME, SERIALIZER_TYPE_TO_CLASS_MAP, ) from sagemaker.jumpstart.enums import ( @@ -29,6 +28,7 @@ JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -309,8 +309,9 @@ def _retrieve_default_content_type( str: the default content type to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -359,8 +360,9 @@ def _retrieve_default_accept_type( str: the default accept type to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -410,8 +412,9 @@ def _retrieve_supported_accept_types( list: the supported accept types to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -461,8 +464,9 @@ def _retrieve_supported_content_types( list: the supported content types to use for the model. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 60af520a6e..cffd46d043 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -15,13 +15,13 @@ from typing import Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -60,8 +60,9 @@ def _retrieve_resource_name_base( str: the default resource name. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 9f01a7af77..369acac85f 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -17,13 +17,13 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, JumpStartModelType, ) from sagemaker.jumpstart.utils import ( + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -89,8 +89,9 @@ def _retrieve_default_resources( retrieve default resource requirements """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index c1b037ce61..f69732d2e0 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -17,13 +17,13 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, verify_model_region_and_return_specs, ) from sagemaker.session import Session @@ -71,8 +71,9 @@ def _retrieve_script_uri( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -132,8 +133,9 @@ def _model_supports_inference_script_uri( bool: the support status for script uri with inference. """ - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 7682ab3817..fff421ab32 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -26,7 +26,6 @@ ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, - JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, MODEL_TYPE_TO_MANIFEST_MAP, @@ -62,24 +61,21 @@ class JumpStartModelsCache: for launching JumpStart models from the SageMaker SDK. """ - # fmt: off def __init__( self, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_horizon: datetime.timedelta = - JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, - max_semantic_version_cache_items: int = - JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, - semantic_version_cache_expiration_horizon: datetime.timedelta = - JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, - manifest_file_s3_key: str = - JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + s3_cache_expiration_horizon: datetime.timedelta = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + semantic_version_cache_expiration_horizon: datetime.timedelta = ( + JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON + ), + manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, - ) -> None: # fmt: on + ) -> None: """Initialize a ``JumpStartModelsCache`` instance. Args: @@ -102,7 +98,10 @@ def __init__( s3_client (Optional[boto3.client]): s3 client to use. Default: None. """ - self._region = region + self._region = region or utils.get_region_fallback( + s3_bucket_name=s3_bucket_name, s3_client=s3_client + ) + self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, @@ -165,9 +164,7 @@ def set_manifest_file_s3_key( } property_name = file_mapping.get(file_type) if not property_name: - raise ValueError( - self._file_type_error_msg(file_type, manifest_only=True) - ) + raise ValueError(self._file_type_error_msg(file_type, manifest_only=True)) if key != property_name: setattr(self, property_name, key) self.clear() @@ -180,9 +177,7 @@ def get_manifest_file_s3_key( return self._manifest_file_s3_key if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: return self._proprietary_manifest_s3_key - raise ValueError( - self._file_type_error_msg(file_type, manifest_only=True) - ) + raise ValueError(self._file_type_error_msg(file_type, manifest_only=True)) def set_s3_bucket_name(self, s3_bucket_name: str) -> None: """Set s3 bucket used for cache.""" @@ -235,7 +230,8 @@ def _model_id_retrieval_function( sm_version = utils.get_sagemaker_version() manifest = self._s3_cache.get( JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] + ) )[0].formatted_content versions_compatible_with_sagemaker = [ @@ -252,7 +248,8 @@ def _model_id_retrieval_function( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) for header in manifest.values() # type: ignore + Version(header.version) + for header in manifest.values() # type: ignore if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( @@ -282,9 +279,7 @@ def _model_id_retrieval_function( raise KeyError(error_msg) error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. " - error_msg += ( - f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " - ) + error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " other_model_id_version = None if model_type == JumpStartModelType.OPEN_WEIGHTS: @@ -293,19 +288,17 @@ def _model_id_retrieval_function( ) # all versions here are incompatible with sagemaker elif model_type == JumpStartModelType.PROPRIETARY: all_possible_model_id_version = [ - header.version for header in manifest.values() # type: ignore + header.version + for header in manifest.values() # type: ignore if header.model_id == model_id ] other_model_id_version = ( - None - if not all_possible_model_id_version - else all_possible_model_id_version[0] + None if not all_possible_model_id_version else all_possible_model_id_version[0] ) if other_model_id_version is not None: error_msg += ( - f"Consider using model ID '{model_id}' with version " - f"'{other_model_id_version}'." + f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'." ) else: possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore @@ -347,15 +340,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], def _is_local_metadata_mode(self) -> bool: """Returns True if the cache should use local metadata mode, based off env variables.""" - return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ - and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) - and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ - and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])) + return ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) + and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]) + ) def _get_json_file( - self, - key: str, - filetype: JumpStartS3FileType + self, key: str, filetype: JumpStartS3FileType ) -> Tuple[Union[dict, list], Optional[str]]: """Returns json file either from s3 or local file system. @@ -379,21 +372,19 @@ def _get_json_md5_hash(self, key: str): return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"] def _get_json_file_from_local_override( - self, - key: str, - filetype: JumpStartS3FileType + self, key: str, filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - metadata_local_root = ( - os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] - ) + metadata_local_root = os.environ[ + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE + ] elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") file_path = os.path.join(metadata_local_root, key) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) return data @@ -437,9 +428,7 @@ def _retrieval_function( model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) return JumpStartCachedS3ContentValue(formatted_content=model_specs) - raise ValueError( - self._file_type_error_msg(file_type) - ) + raise ValueError(self._file_type_error_msg(file_type)) def get_manifest( self, @@ -448,7 +437,8 @@ def get_manifest( """Return entire JumpStart models manifest.""" manifest_dict = self._s3_cache.get( JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] + ) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest @@ -505,16 +495,14 @@ def _select_version( except InvalidSpecifier: raise KeyError(f"Bad semantic version: {version_str}") available_versions_filtered = list(spec.filter(available_versions)) - return ( - str(max(available_versions_filtered)) if available_versions_filtered != [] else None - ) + return str(max(available_versions_filtered)) if available_versions_filtered != [] else None def _get_header_impl( self, model_id: str, semantic_version_str: str, attempt: int = 0, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -537,7 +525,8 @@ def _get_header_impl( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] + ) )[0].formatted_content try: @@ -553,7 +542,7 @@ def get_specs( self, model_id: str, version_str: str, - model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. @@ -566,16 +555,12 @@ def get_specs( header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey( - MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key - ) + JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) ) if not cache_hit and "*" in version_str: JUMPSTART_LOGGER.warning( - get_wildcard_model_version_msg( - header.model_id, version_str, header.version - ) + get_wildcard_model_version_msg(header.model_id, version_str, header.version) ) return specs.formatted_content diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 4dada409f5..bac076ea4a 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -508,7 +508,7 @@ def _validate_model_id_and_get_type_hook(): return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, - region=region, + region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 86630fcfb8..875ec9d003 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -192,8 +192,8 @@ def get_init_kwargs( estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) - estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs) @@ -393,7 +393,9 @@ def get_deploy_kwargs( def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: """Sets region in kwargs based on default or override, returns full kwargs.""" - kwargs.region = kwargs.region or JUMPSTART_DEFAULT_REGION_NAME + kwargs.region = ( + kwargs.region or kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME + ) return kwargs @@ -507,6 +509,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + region=kwargs.region, instance_type=kwargs.instance_type, ) @@ -553,6 +556,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 63b4898877..28746990e3 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -129,7 +129,9 @@ def get_default_predictor( def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets region kwargs based on default or override, returns full kwargs.""" - kwargs.region = kwargs.region or JUMPSTART_DEFAULT_REGION_NAME + kwargs.region = ( + kwargs.region or kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME + ) return kwargs @@ -758,8 +760,8 @@ def get_init_kwargs( model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index a8f4c0e9fc..4529bc11b9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -285,7 +285,7 @@ def _validate_model_id_and_type(): return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, - region=region, + region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 485354e802..85a041379a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -23,7 +23,6 @@ from sagemaker.jumpstart import accessors from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, PROPRIETARY_MODEL_SPEC_PREFIX, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType @@ -38,6 +37,7 @@ from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, get_sagemaker_version, verify_model_region_and_return_specs, validate_model_id_and_get_type, @@ -156,7 +156,7 @@ def extract_model_type_filter_representation(spec_key: str) -> str: def list_jumpstart_tasks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List tasks for JumpStart, and optionally apply filters to result. @@ -168,11 +168,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) tasks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( filter=filter, region=region, sagemaker_session=sagemaker_session @@ -184,7 +187,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin def list_jumpstart_frameworks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List frameworks for JumpStart, and optionally apply filters to result. @@ -196,11 +199,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin (eg. ``"task == ic"``). If this argument is not supplied, all frameworks will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) frameworks: Set[str] = set() for model_id, _ in _generate_jumpstart_model_versions( filter=filter, region=region, sagemaker_session=sagemaker_session @@ -212,7 +218,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin def list_jumpstart_scripts( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """List scripts for JumpStart, and optionally apply filters to result. @@ -224,10 +230,13 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or ( isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower() ): @@ -255,7 +264,7 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin def list_jumpstart_models( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, list_incomplete_models: bool = False, list_old_models: bool = False, list_versions: bool = False, @@ -270,7 +279,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). list_incomplete_models (bool): Optional. If a model does not contain metadata fields requested by the filter, and the filter cannot be resolved to a include/not include, whether the model should be included. By default, these models are omitted from results. @@ -283,6 +292,9 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_id_version_dict: Dict[str, List[str]] = dict() for model_id, version in _generate_jumpstart_model_versions( filter=filter, @@ -312,7 +324,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, list_incomplete_models: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Generator: @@ -325,7 +337,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin (e.g. ``"task == ic"``). If this argument is not supplied, all models will be generated. (Default: Constant(BooleanValues.TRUE)). region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding - models. (Default: JUMPSTART_DEFAULT_REGION_NAME). + models. (Default: None). list_incomplete_models (bool): Optional. If a model does not contain metadata fields requested by the filter, and the filter cannot be resolved to a include/not include, whether the model should be included. By default, these models are omitted from @@ -334,6 +346,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) + prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( region=region, s3_client=sagemaker_session.s3_client, @@ -484,7 +500,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, def get_model_url( model_id: str, model_version: str, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieve web url describing pretrained model. @@ -493,7 +509,7 @@ def get_model_url( model_id (str): The model ID for which to retrieve the url. model_version (str): The model version for which to retrieve the url. region (str): Optional. The region from which to retrieve metadata. - (Default: JUMPSTART_DEFAULT_REGION_NAME) + (Default: None) sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to retrieve the model url. """ @@ -504,6 +520,9 @@ def get_model_url( sagemaker_session=sagemaker_session, ) + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 242118c56e..595f801598 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -22,12 +22,12 @@ from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, + get_region_fallback, ) from sagemaker.session import Session @@ -125,12 +125,14 @@ class PayloadSerializer: def __init__( self, bucket: Optional[str] = None, - region: str = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, s3_client: Optional[boto3.client] = None, ) -> None: """Initializes PayloadSerializer object.""" self.bucket = bucket or get_jumpstart_content_bucket() - self.region = region + self.region = region or get_region_fallback( + s3_client=s3_client, + ) self.s3_client = s3_client def get_bytes_payload_with_s3_references( diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 62ccba7900..5f51173b24 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -540,7 +540,7 @@ def verify_model_region_and_return_specs( model_id: Optional[str], version: Optional[str], scope: Optional[str], - region: str, + region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -576,6 +576,10 @@ def verify_model_region_and_return_specs( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) + if scope is None: raise ValueError( "Must specify `model_scope` argument to retrieve model " @@ -842,3 +846,42 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version + + +def get_region_fallback( + s3_bucket_name: Optional[str] = None, + s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Returns region to use for JumpStart functionality implicitly via session objects.""" + regions_in_s3_bucket_name: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if s3_bucket_name is not None + if region in s3_bucket_name + } + regions_in_s3_client_endpoint_url: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if s3_client is not None + if region in s3_client._endpoint.host + } + + regions_in_sagemaker_session: Set[str] = { + region + for region in constants.JUMPSTART_REGION_NAME_SET + if sagemaker_session + if region == sagemaker_session.boto_region_name + } + + combined_regions = regions_in_s3_client_endpoint_url.union( + regions_in_s3_bucket_name, regions_in_sagemaker_session + ) + + if len(combined_regions) > 1: + raise ValueError("Unable to resolve a region name from the s3 bucket and client provided.") + + if len(combined_regions) == 0: + return constants.JUMPSTART_DEFAULT_REGION_NAME + + return list(combined_regions)[0] diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 3199e5fc2e..c7098a1185 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from typing import Any, Dict, List, Optional from sagemaker import session -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.enums import ( HyperparameterValidationMode, @@ -24,7 +23,7 @@ ) from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter -from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.utils import get_region_fallback, verify_model_region_and_return_specs def _validate_hyperparameter( @@ -168,7 +167,7 @@ def validate_hyperparameters( model_version: str, hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, - region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, + region: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -184,8 +183,7 @@ def validate_hyperparameters( to this function will be validated, the missing hyperparameters will be ignored. If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated. If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. - region (str): Region for which to validate hyperparameters. (Default: JumpStart - default region). + region (str): Region for which to validate hyperparameters. (Default: None). sagemaker_session (Optional[Session]): Custom SageMaker Session to use. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -202,11 +200,15 @@ def validate_hyperparameters( """ + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) if validation_mode is None: validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED - if region is None: - region = JUMPSTART_DEFAULT_REGION_NAME + region = region or get_region_fallback( + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 49c18beec2..11165a0625 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -23,7 +23,8 @@ mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 7765d6eaad..d116c8121b 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -22,7 +22,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 5328533da5..f0102068e7 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -22,9 +22,10 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +region = "us-west-2" mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index cc1aad8a44..5f00f93abf 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -25,7 +25,8 @@ mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index a13fba87ae..40ee4978cf 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -24,7 +24,8 @@ mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 7a5df4ac93..07418f8ddb 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -23,8 +23,9 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +region = "us-west-2" mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 6c80c97f33..88b95b9403 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -37,8 +37,9 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.side_effect = get_spec_from_base_spec patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) image_uris.retrieve( framework=None, diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 982c7f1702..2e51afd3f7 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -34,7 +34,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_training_instance_types = instance_types.retrieve_default( region=region, @@ -178,7 +178,8 @@ def test_jumpstart_inference_instance_type_variants(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "inference-instance-types-variant-model", "*" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index fe4b122c4a..1e048ef0dd 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -16,6 +16,7 @@ from unittest import mock import unittest from inspect import signature +from mock import Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -30,12 +31,16 @@ from sagemaker.jumpstart.artifacts.metric_definitions import ( _retrieve_default_training_metric_definitions, ) -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.estimator import JumpStartEstimator from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from tests.integ.sagemaker.jumpstart.utils import get_training_dataset_for_model_and_version from sagemaker.model import Model @@ -44,6 +49,7 @@ get_special_model_spec, overwrite_dictionary, ) +import boto3 execution_role = "fake role! do not use!" @@ -1553,7 +1559,8 @@ def test_training_passes_session_to_deploy( mock_get_model_specs.side_effect = get_special_model_spec mock_role = f"dsfsdfsd{time.time()}" - mock_sagemaker_session = mock.MagicMock(sagemaker_config={}) + region = "us-west-2" + mock_sagemaker_session = mock.MagicMock(sagemaker_config={}, boto_region_name=region) mock_sagemaker_session.get_caller_identity_arn = lambda: mock_role estimator = JumpStartEstimator( @@ -1758,6 +1765,56 @@ def test_model_artifact_variant_estimator( ], ) + @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_estimator_session( + self, + mock_get_model_specs: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_deploy, + mock_fit, + mock_init, + get_default_predictor, + ): + + mock_validate_model_id_and_get_type.return_value = True + + model_id, _ = "js-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + region = "eu-west-1" # some non-default region + + if region == JUMPSTART_DEFAULT_REGION_NAME: + region = "us-west-2" + + session = Session(boto_session=boto3.session.Session(region_name=region)) + + assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME + + session.get_caller_identity_arn = Mock(return_value="blah") + + estimator = JumpStartEstimator(model_id=model_id, sagemaker_session=session) + estimator.fit() + + estimator.deploy() + + assert len(mock_get_model_specs.call_args_list) > 1 + + regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list} + + assert len(regions) == 1 + assert list(regions)[0] == region + + s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list} + assert len(s3_clients) == 1 + assert list(s3_clients)[0] == session.s3_client + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index ba4ba0bb13..c4a96d4120 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,18 +15,22 @@ from typing import Optional, Set from unittest import mock import unittest -from mock import MagicMock +from mock import MagicMock, Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.jumpstart.artifacts.environment_variables import ( _retrieve_default_environment_variables, ) -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model from sagemaker.predictor import Predictor +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from sagemaker.enums import EndpointType from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements @@ -38,6 +42,7 @@ get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, ) +import boto3 execution_role = "fake role! do not use!" region = "us-west-2" @@ -950,7 +955,7 @@ def test_jumpstart_model_tags( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -987,7 +992,9 @@ def test_jumpstart_model_tags_disabled( mock_get_model_specs.side_effect = get_special_model_spec settings = SessionSettings(include_jumpstart_tags=False) - mock_session = MagicMock(sagemaker_config={}, settings=settings) + mock_session = MagicMock( + sagemaker_config={}, settings=settings, boto_region_name="us-west-2" + ) model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -1018,7 +1025,7 @@ def test_jumpstart_model_package_arn( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) @@ -1053,7 +1060,7 @@ def test_jumpstart_model_package_arn_override( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}, boto_region_name="us-west-2") model_package_arn = ( "arn:aws:sagemaker:us-west-2:867530986753:model-package/" @@ -1312,6 +1319,52 @@ def test_model_registry_accept_and_response_types( response_types=["application/json;verbose", "application/json"], ) + @mock.patch("sagemaker.jumpstart.model.get_default_predictor") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_model_session( + self, + mock_get_model_specs: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_deploy, + mock_init, + get_default_predictor, + ): + + mock_validate_model_id_and_get_type.return_value = True + + model_id, _ = "model_data_s3_prefix_model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + region = "eu-west-1" # some non-default region + + if region == JUMPSTART_DEFAULT_REGION_NAME: + region = "us-west-2" + + session = Session(boto_session=boto3.session.Session(region_name=region)) + + assert session.boto_region_name != JUMPSTART_DEFAULT_REGION_NAME + + session.get_caller_identity_arn = Mock(return_value="blah") + + model = JumpStartModel(model_id=model_id, sagemaker_session=session) + model.deploy() + + assert len(mock_get_model_specs.call_args_list) > 1 + + regions = {call[1]["region"] for call in mock_get_model_specs.call_args_list} + + assert len(regions) == 1 + assert list(regions)[0] == region + + s3_clients = {call[1]["s3_client"] for call in mock_get_model_specs.call_args_list} + assert len(s3_clients) == 1 + assert list(s3_clients)[0] == session.s3_client + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 21112926a5..3d9b5cef6a 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -329,7 +329,8 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs): class RetrieveModelPackageArnTest(unittest.TestCase): - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -439,7 +440,8 @@ def test_retrieve_model_package_arn( class PrivateJumpStartBucketTest(unittest.TestCase): - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index cb54722d48..c81d5639e5 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -15,7 +15,9 @@ from unittest import TestCase from mock.mock import Mock, patch import pytest +import boto3 import random +from sagemaker import session from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -1450,3 +1452,50 @@ def test_logger_disabled(self, mocked_emit: Mock): JUMPSTART_LOGGER.warning("Self destruct in 3...2...1...") mocked_emit.assert_not_called() + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client, sagemaker_session, region", + [ + ( + "jumpstart-cache-prod", + boto3.client("s3", region_name="blah-blah"), + session.Session(boto3.Session(region_name="blah-blah")), + JUMPSTART_DEFAULT_REGION_NAME, + ), + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-west-2"), + session.Session(boto3.Session(region_name="us-west-2")), + "us-west-2", + ), + ("jumpstart-cache-prod", boto3.client("s3", region_name="us-east-2"), None, "us-east-2"), + ], +) +def test_get_region_fallback_success(s3_bucket_name, s3_client, sagemaker_session, region): + assert region == utils.get_region_fallback(s3_bucket_name, s3_client, sagemaker_session) + + +@pytest.mark.parametrize( + "s3_bucket_name, s3_client, sagemaker_session", + [ + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-east-2"), + session.Session(boto3.Session(region_name="us-west-2")), + ), + ( + "jumpstart-cache-prod-us-west-2", + boto3.client("s3", region_name="us-west-2"), + session.Session(boto3.Session(region_name="eu-north-1")), + ), + ( + "jumpstart-cache-prod-us-west-2-us-east-2", + boto3.client("s3", region_name="us-east-2"), + None, + ), + ], +) +def test_get_region_fallback_failure(s3_bucket_name, s3_client, sagemaker_session): + with pytest.raises(ValueError): + utils.get_region_fallback(s3_bucket_name, s3_client, sagemaker_session) diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index 608a32a005..835a09a58c 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -23,7 +23,8 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec mock_client = boto3.client("s3") -mock_session = Mock(s3_client=mock_client) +region = "us-west-2" +mock_session = Mock(s3_client=mock_client, boto_region_name=region) @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @@ -36,7 +37,8 @@ def test_jumpstart_default_metric_definitions( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id = "pytorch-ic-mobilenet-v2" region = "us-west-2" diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 8d75731b06..8ec9478d8a 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -39,7 +39,8 @@ def test_jumpstart_common_model_uri( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_uris.retrieve( model_scope="training", diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 7b5e7a598d..1c0cfa35b3 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -37,7 +37,7 @@ def test_jumpstart_resource_requirements( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "huggingface-llm-mistral-7b-instruct", "*" default_inference_resource_requirements = resource_requirements.retrieve_default( @@ -121,7 +121,7 @@ def test_jumpstart_no_supported_resource_requirements( model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_inference_resource_requirements = resource_requirements.retrieve_default( region=region, diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index c797ba3559..16b7256ed2 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -39,7 +39,8 @@ def test_jumpstart_common_script_uri( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) script_uris.retrieve( script_scope="training", diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index c2253726bf..90ec5df6b5 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -40,7 +40,7 @@ def test_jumpstart_default_serializers( model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + mock_session = Mock(s3_client=mock_client, boto_region_name=region) default_serializer = serializers.retrieve_default( region=region, @@ -75,7 +75,8 @@ def test_jumpstart_serializer_options( patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" From 5828ad44d0c1ed8be04668ed85ad66d71930f752 Mon Sep 17 00:00:00 2001 From: ruhanprasad <52712386+ruhanprasad@users.noreply.github.com> Date: Wed, 13 Mar 2024 10:23:14 -0700 Subject: [PATCH 09/20] fix: add PT 2.2 support for smdistributed, pytorchddp, and torch_distributed distributions (#4480) * Add support for smdistributed, pytorchddp, torch_distributed for PT 2.2 * formatting * formatting --------- Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> --- src/sagemaker/estimator.py | 5 +---- src/sagemaker/fw_utils.py | 1 - tests/unit/test_fw_utils.py | 2 ++ 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 501c826f82..22dd163ede 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -3285,10 +3285,7 @@ class Framework(EstimatorBase): """ _framework_name = None - UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = ( - "2.0.1-gpu-py310-cu121", - "2.0-gpu-py310-cu121", - ) + UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = ("2.0.1-gpu-py310-cu121",) def __init__( self, diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index ca5c09a96c..cda94b1c2c 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -161,7 +161,6 @@ "2.2.0", ] - TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [ "1.13.1", "2.0.0", diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 7fa84acf1a..e955d68227 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -933,6 +933,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "2.1.0", "py310", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "2.2.0", "py310", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi), @@ -957,6 +958,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "2.1.0", "py310", smdataparallel_enabled_custom_mpi), + ("ml.p3.16xlarge", "pytorch", "2.2.0", "py310", smdataparallel_enabled_custom_mpi), ] for instance_type, framework_name, framework_version, py_version, distribution in good_args: fw_utils._validate_smdataparallel_args( From 1cdd446cad0830238631a16c2ecb7005e28bec12 Mon Sep 17 00:00:00 2001 From: ASHWIN KRISHNA <38850354+akrishna1995@users.noreply.github.com> Date: Wed, 13 Mar 2024 11:47:03 -0700 Subject: [PATCH 10/20] change: split coverage out from testenv in tox.ini (#4495) Co-authored-by: Ashwin Krishna --- CONTRIBUTING.md | 4 ++-- tox.ini | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c1ea591249..24226af4ee 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -77,8 +77,8 @@ Before sending us a pull request, please ensure that: 1. Install coverage using `pip install .[test]` 1. cd into the sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk` 1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit` - -You can also run a single test with the following command: `tox -e py310 -- -s -vv ::` +1. You can also run a single test with the following command: `tox -e py310 -- -s -vv ::` +1. You can run coverage via runcvoerage env : `tox -e runcoverage -- tests/unit` or `tox -e py310 -- tests/unit --cov=sagemaker --cov-append --cov-report xml` * Note that the coverage test will fail if you only run a single test, so make sure to surround the command with `export IGNORE_COVERAGE=-` and `unset IGNORE_COVERAGE` * Example: `export IGNORE_COVERAGE=- ; tox -e py310 -- -s -vv tests/unit/test_estimator.py::test_sagemaker_model_s3_uri_invalid ; unset IGNORE_COVERAGE` diff --git a/tox.ini b/tox.ini index d990467b3b..0c1c347c0d 100644 --- a/tox.ini +++ b/tox.ini @@ -86,12 +86,17 @@ commands = pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'dill>=0.3.8' - pytest --cov=sagemaker --cov-append {posargs} -{env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 + pytest {posargs} deps = .[test] depends = {py38,py39,py310,p311}: clean +[testenv:runcoverage] +description = run unit tests with coverage +commands = + pytest --cov=sagemaker --cov-append {posargs} + {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 + [testenv:flake8] skipdist = true skip_install = true From d15a6399549b745f0d0056fff13756bf40591f1d Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:04:22 -0700 Subject: [PATCH 11/20] change: add ci-health checks (#4493) --- .github/workflows/codebuild-ci-health.yml | 97 +++++++++++++++++++++++ README.rst | 4 + 2 files changed, 101 insertions(+) create mode 100644 .github/workflows/codebuild-ci-health.yml diff --git a/.github/workflows/codebuild-ci-health.yml b/.github/workflows/codebuild-ci-health.yml new file mode 100644 index 0000000000..246d756857 --- /dev/null +++ b/.github/workflows/codebuild-ci-health.yml @@ -0,0 +1,97 @@ +name: CI Health +on: + schedule: + - cron: "0 */3 * * *" + workflow_dispatch: + +permissions: + id-token: write # This is required for requesting the JWT + +jobs: + codestyle-doc-tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Codestyle & Doc Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-health-codestyle-doc-tests + unit-tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["py38", "py39", "py310"] + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Unit Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-health-unit-tests + env-vars-for-codebuild: | + PY_VERSION + env: + PY_VERSION: ${{ matrix.python-version }} + integ-tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Integ Tests + uses: aws-actions/aws-codebuild-run-build@v1 + id: codebuild + with: + project-name: sagemaker-python-sdk-ci-health-integ-tests + slow-tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Slow Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-health-slow-tests + localmode-tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Local Mode Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-health-localmode-tests + notebook-tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Notebook Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-health-notebook-tests \ No newline at end of file diff --git a/README.rst b/README.rst index 80281bad5a..827bfbfb36 100644 --- a/README.rst +++ b/README.rst @@ -22,6 +22,10 @@ SageMaker Python SDK :target: https://sagemaker.readthedocs.io/en/stable/ :alt: Documentation Status +.. image:: https://github.com/benieric/sagemaker-python-sdk/actions/workflows/codebuild-ci-health.yml/badge.svg + :target: https://github.com/benieric/sagemaker-python-sdk/actions/workflows/codebuild-ci-health.yml + :alt: CI Health + SageMaker Python SDK is an open source library for training and deploying machine learning models on Amazon SageMaker. With the SDK, you can train and deploy models using popular deep learning frameworks **Apache MXNet** and **TensorFlow**. From 15a40ff0e84cba79311d543bcca5a7a42f5eb09a Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Wed, 13 Mar 2024 15:49:40 -0400 Subject: [PATCH 12/20] feat: tgi optimum 0.0.19, 0.0.20 releases (#4496) --- .../huggingface-llm-neuronx.json | 58 +++++++++++++++++++ .../image_uris/test_huggingface_llm.py | 2 + 2 files changed, 60 insertions(+) diff --git a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json index 33975042b5..994c746368 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json +++ b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json @@ -93,6 +93,64 @@ "container_version": { "inf2": "ubuntu22.04" } + }, + "0.0.19": { + "py_versions": [ + "py310" + ], + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "tag_prefix": "1.13.1-optimum0.0.19", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } + }, + "0.0.20": { + "py_versions": [ + "py310" + ], + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "tag_prefix": "1.13.1-optimum0.0.20", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } } } } diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index f8b74a396f..4fc224e1d1 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -35,6 +35,8 @@ "0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04", "0.0.17": "1.13.1-optimum0.0.17-neuronx-py310-ubuntu22.04", "0.0.18": "1.13.1-optimum0.0.18-neuronx-py310-ubuntu22.04", + "0.0.19": "1.13.1-optimum0.0.19-neuronx-py310-ubuntu22.04", + "0.0.20": "1.13.1-optimum0.0.20-neuronx-py310-ubuntu22.04", }, } From fada4bf27340365f8ff048675bf91987ae64ceae Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Thu, 14 Mar 2024 10:40:08 -0700 Subject: [PATCH 13/20] feature: Add support for Streaming Inference (#4497) * feature: Add support for Streaming Inference * fix: codestyle-docs-test * fix: codestyle-docs-test --- src/sagemaker/base_predictor.py | 82 ++++++++ src/sagemaker/exceptions.py | 20 ++ src/sagemaker/iterators.py | 186 ++++++++++++++++++ src/sagemaker/jumpstart/cache.py | 4 +- .../lmi-model-falcon-7b/mymodel-7B.tar.gz | Bin 0 -> 382 bytes tests/integ/test_predict_stream.py | 109 ++++++++++ .../sagemaker/iterators/test_iterators.py | 126 ++++++++++++ tests/unit/test_predictor.py | 75 +++++++ 8 files changed, 601 insertions(+), 1 deletion(-) create mode 100644 src/sagemaker/iterators.py create mode 100644 tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz create mode 100644 tests/integ/test_predict_stream.py create mode 100644 tests/unit/sagemaker/iterators/test_iterators.py diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 76b83c25cd..1a7eea9cd7 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -52,6 +52,7 @@ JSONSerializer, NumpySerializer, ) +from sagemaker.iterators import ByteIterator from sagemaker.session import production_variant, Session from sagemaker.utils import name_from_base, stringify_object, format_tags @@ -225,6 +226,7 @@ def _create_request_args( target_variant=None, inference_id=None, custom_attributes=None, + target_container_hostname=None, ): """Placeholder docstring""" @@ -286,9 +288,89 @@ def _create_request_args( if self._get_component_name(): args["InferenceComponentName"] = self.component_name + if target_container_hostname: + args["TargetContainerHostname"] = target_container_hostname + args["Body"] = data return args + def predict_stream( + self, + data, + initial_args=None, + target_variant=None, + inference_id=None, + custom_attributes=None, + component_name: Optional[str] = None, + target_container_hostname=None, + iterator=ByteIterator, + ): + """Return the inference from the specified endpoint. + + Args: + data (object): Input data for which you want the model to provide + inference. If a serializer was specified when creating the + Predictor, the result of the serializer is sent as input + data. Otherwise the data must be sequence of bytes, and the + predict method then sends the bytes in the request body as is. + initial_args (dict[str,str]): Optional. Default arguments for boto3 + ``invoke_endpoint_with_response_stream`` call. Default is None (no default + arguments). (Default: None) + target_variant (str): Optional. The name of the production variant to run an inference + request on (Default: None). Note that the ProductionVariant identifies the + model you want to host and the resources you want to deploy for hosting it. + inference_id (str): Optional. If you provide a value, it is added to the captured data + when you enable data capture on the endpoint (Default: None). + custom_attributes (str): Optional. Provides additional information about a request for + an inference submitted to a model hosted at an Amazon SageMaker endpoint. + The information is an opaque value that is forwarded verbatim. You could use this + value, for example, to provide an ID that you can use to track a request or to + provide other metadata that a service endpoint was programmed to process. The value + must consist of no more than 1024 visible US-ASCII characters. + + The code in your model is responsible for setting or updating any custom attributes + in the response. If your code does not set this value in the response, an empty + value is returned. For example, if a custom attribute represents the trace ID, your + model can prepend the custom attribute with Trace ID: in your post-processing + function (Default: None). + component_name (str): Optional. Name of the Amazon SageMaker inference component + corresponding the predictor. (Default: None) + target_container_hostname (str): Optional. If the endpoint hosts multiple containers + and is configured to use direct invocation, this parameter specifies the host name + of the container to invoke. (Default: None). + iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator class which provides + an iterable interface to iterate Event stream response from Inference Endpoint. + An object of the iterator class provided will be returned by the predict_stream + method (Default::class:`~sagemaker.iterators.ByteIterator`). Iterators defined in + :class:`~sagemaker.iterators` or custom iterators (needs to inherit + :class:`~sagemaker.iterators.BaseIterator`) can be specified as an input. + + Returns: + object (:class:`~sagemaker.iterators.BaseIterator`): An iterator object which would + allow iteration on EventStream response will be returned. The object would be + instantiated from `predict_stream` method's `iterator` parameter. + """ + # [TODO]: clean up component_name in _create_request_args + request_args = self._create_request_args( + data=data, + initial_args=initial_args, + target_variant=target_variant, + inference_id=inference_id, + custom_attributes=custom_attributes, + target_container_hostname=target_container_hostname, + ) + + inference_component_name = component_name or self._get_component_name() + if inference_component_name: + request_args["InferenceComponentName"] = inference_component_name + + response = ( + self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream( + **request_args + ) + ) + return iterator(response["Body"]) + def update_endpoint( self, initial_instance_count=None, diff --git a/src/sagemaker/exceptions.py b/src/sagemaker/exceptions.py index b9d97cc241..88ffa0a591 100644 --- a/src/sagemaker/exceptions.py +++ b/src/sagemaker/exceptions.py @@ -86,3 +86,23 @@ class AsyncInferenceModelError(AsyncInferenceError): def __init__(self, message): super().__init__(message=message) + + +class ModelStreamError(Exception): + """Raised when invoke_endpoint_with_response_stream Response returns ModelStreamError""" + + def __init__(self, message="An error occurred", code=None): + self.message = message + self.code = code + if code is not None: + super().__init__(f"{message} (Code: {code})") + else: + super().__init__(message) + + +class InternalStreamFailure(Exception): + """Raised when invoke_endpoint_with_response_stream Response returns InternalStreamFailure""" + + def __init__(self, message="An error occurred"): + self.message = message + super().__init__(self.message) diff --git a/src/sagemaker/iterators.py b/src/sagemaker/iterators.py new file mode 100644 index 0000000000..38a43121a1 --- /dev/null +++ b/src/sagemaker/iterators.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Implements iterators for deserializing data returned from an inference streaming endpoint.""" +from __future__ import absolute_import + +from abc import ABC, abstractmethod +import io + +from sagemaker.exceptions import ModelStreamError, InternalStreamFailure + + +def handle_stream_errors(chunk): + """Handle API Response errors within `invoke_endpoint_with_response_stream` API if any. + + Args: + chunk (dict): A chunk of response received as part of `botocore.eventstream.EventStream` + response object. + + Raises: + ModelStreamError: If `ModelStreamError` error is detected in a chunk of + `botocore.eventstream.EventStream` response object. + InternalStreamFailure: If `InternalStreamFailure` error is detected in a chunk of + `botocore.eventstream.EventStream` response object. + """ + if "ModelStreamError" in chunk: + raise ModelStreamError( + chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"] + ) + if "InternalStreamFailure" in chunk: + raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"]) + + +class BaseIterator(ABC): + """Abstract base class for Inference Streaming iterators. + + Provides a skeleton for customization requiring the overriding of iterator methods + __iter__ and __next__. + + Tenets of iterator class for Streaming Inference API Response + (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ + sagemaker-runtime/client/invoke_endpoint_with_response_stream.html): + 1. Needs to accept an botocore.eventstream.EventStream response. + 2. Needs to implement logic in __next__ to: + 2.1. Concatenate and provide next chunk of response from botocore.eventstream.EventStream. + While doing so parse the response_chunk["PayloadPart"]["Bytes"]. + 2.2. If PayloadPart not in EventStream response, handle Errors + [Recommended to use `iterators.handle_stream_errors` method]. + """ + + def __init__(self, event_stream): + """Initialises a Iterator object to help parse the byte event stream input. + + Args: + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + self.event_stream = event_stream + + @abstractmethod + def __iter__(self): + """Abstract method, returns an iterator object itself""" + return self + + @abstractmethod + def __next__(self): + """Abstract method, is responsible for returning the next element in the iteration""" + + +class ByteIterator(BaseIterator): + """A helper class for parsing the byte Event Stream input to provide Byte iteration.""" + + def __init__(self, event_stream): + """Initialises a BytesIterator Iterator object + + Args: + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + super().__init__(event_stream) + self.byte_iterator = iter(event_stream) + + def __iter__(self): + """Returns an iterator object itself, which allows the object to be iterated. + + Returns: + iter : object + An iterator object representing the iterable. + """ + return self + + def __next__(self): + """Returns the next chunk of Byte directly.""" + # Even with "while True" loop the function still behaves like a generator + # and sends the next new byte chunk. + while True: + chunk = next(self.byte_iterator) + if "PayloadPart" not in chunk: + # handle API response errors and force terminate. + handle_stream_errors(chunk) + # print and move on to next response byte + print("Unknown event type:" + chunk) + continue + return chunk["PayloadPart"]["Bytes"] + + +class LineIterator(BaseIterator): + """A helper class for parsing the byte Event Stream input to provide Line iteration.""" + + def __init__(self, event_stream): + """Initialises a LineIterator Iterator object + + Args: + event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated. + """ + super().__init__(event_stream) + self.byte_iterator = iter(self.event_stream) + self.buffer = io.BytesIO() + self.read_pos = 0 + + def __iter__(self): + """Returns an iterator object itself, which allows the object to be iterated. + + Returns: + iter : object + An iterator object representing the iterable. + """ + return self + + def __next__(self): + r"""Returns the next Line for an Line iterable. + + The output of the event stream will be in the following format: + + ``` + b'{"outputs": [" a"]}\n' + b'{"outputs": [" challenging"]}\n' + b'{"outputs": [" problem"]}\n' + ... + ``` + + While usually each PayloadPart event from the event stream will contain a byte array + with a full json, this is not guaranteed and some of the json objects may be split across + PayloadPart events. For example: + ``` + {'PayloadPart': {'Bytes': b'{"outputs": '}} + {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} + ``` + + This class accounts for this by concatenating bytes written via the 'write' function + and then exposing a method which will return lines (ending with a '\n' character) within + the buffer via the 'scan_lines' function. It maintains the position of the last read + position to ensure that previous bytes are not exposed again. + + Returns: + str: Read and return one line from the event stream. + """ + # Even with "while True" loop the function still behaves like a generator + # and sends the next new concatenated line + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord("\n"): + self.read_pos += len(line) + return line[:-1] + try: + chunk = next(self.byte_iterator) + except StopIteration: + if self.read_pos < self.buffer.getbuffer().nbytes: + continue + raise + if "PayloadPart" not in chunk: + # handle API response errors and force terminate. + handle_stream_errors(chunk) + # print and move on to next response byte + print("Unknown event type:" + chunk) + continue + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fff421ab32..e9a34a21a8 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -65,7 +65,9 @@ def __init__( self, region: Optional[str] = None, max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_horizon: datetime.timedelta = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + s3_cache_expiration_horizon: datetime.timedelta = ( + JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON + ), max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, semantic_version_cache_expiration_horizon: datetime.timedelta = ( JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON diff --git a/tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz b/tests/data/lmi-model-falcon-7b/mymodel-7B.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..6a66178b47a086851a140b8dba294307ac31ace1 GIT binary patch literal 382 zcmV-^0fGJ>iwFP!000001MSkyPV68Q2k@?aioU=oOg|VGY}~mp@eK$qr*knaG(e5| z^fF_nPHu8_qcJ!Be;Zmj^qg{-o+oc;+=!d2;=8a+G|h3${vMCdylzC*3NGrlV4OF+ zF3RTHDmt^oq(fO2!Ta=4+-K|msp-A{k;0>O`^!1_nL@G@zbMC{!EIgtv;R$1o%KK8 z6W&xz6eatj{2%(|{U^7#j^y3_?S-F{_3rX`ACxsRS-WVu8uZwEw-MdOx|qV!r&DC0 zM;r5lq^{;{=-Oe>*KE6^YoSfIV|-wyrM_+r+YY;qiPOgXm6%kZ$tO~M&L{H>t*hjs z4{Fvyk7F*y&^{1Jz80vTRPf`N@2cu_>i?){Ur1KlwXX9;IZk$CY+S4MOPZIY1|KG! z5(W7Xz02_wPZ1_P&s55Cn0b4eoAsWII&5% Date: Thu, 14 Mar 2024 19:54:54 +0100 Subject: [PATCH 14/20] Add AutoML -> AutoMLV2 mapper (#4500) Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> --- src/sagemaker/automl/automlv2.py | 44 ++++++++++++++++++- .../unit/sagemaker/automl/test_auto_ml_v2.py | 27 ++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/automl/automlv2.py b/src/sagemaker/automl/automlv2.py index c855414f0b..8b34f54a95 100644 --- a/src/sagemaker/automl/automlv2.py +++ b/src/sagemaker/automl/automlv2.py @@ -11,13 +11,15 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """A class for SageMaker AutoML V2 Jobs.""" -from __future__ import absolute_import + +from __future__ import absolute_import, annotations import logging from dataclasses import dataclass from typing import Dict, List, Optional, Union from sagemaker import Model, PipelineModel, s3 +from sagemaker.automl.automl import AutoML from sagemaker.automl.candidate_estimator import CandidateEstimator from sagemaker.config import ( AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, @@ -727,6 +729,46 @@ def __init__( self._auto_ml_job_desc = None self._best_candidate = None + @classmethod + def from_auto_ml(cls, auto_ml: AutoML) -> AutoMLV2: + """Create an AutoMLV2 object from an AutoML object. + + This method maps AutoML properties into an AutoMLV2 object, + so you can create AutoMLV2 jobs from the existing AutoML objects. + + Args: + auto_ml (sagemaker.automl.automl.AutoML): An AutoML object from which + an AutoMLV2 object will be created. + """ + auto_ml_v2 = AutoMLV2( + problem_config=AutoMLTabularConfig( + target_attribute_name=auto_ml.target_attribute_name, + feature_specification_s3_uri=auto_ml.feature_specification_s3_uri, + generate_candidate_definitions_only=auto_ml.generate_candidate_definitions_only, + mode=auto_ml.mode, + problem_type=auto_ml.problem_type, + sample_weight_attribute_name=auto_ml.sample_weight_attribute_name, + max_candidates=auto_ml.max_candidate, + max_runtime_per_training_job_in_seconds=auto_ml.max_runtime_per_training_job_in_seconds, # noqa E501 # pylint: disable=c0301 + max_total_job_runtime_in_seconds=auto_ml.total_job_runtime_in_seconds, + ), + base_job_name=auto_ml.base_job_name, + output_path=auto_ml.output_path, + output_kms_key=auto_ml.output_kms_key, + job_objective=auto_ml.job_objective, + validation_fraction=auto_ml.validation_fraction, + auto_generate_endpoint_name=auto_ml.auto_generate_endpoint_name, + endpoint_name=auto_ml.endpoint_name, + role=auto_ml.role, + volume_kms_key=auto_ml.volume_kms_key, + encrypt_inter_container_traffic=auto_ml.encrypt_inter_container_traffic, + vpc_config=auto_ml.vpc_config, + tags=auto_ml.tags, + sagemaker_session=auto_ml.sagemaker_session, + ) + auto_ml_v2._best_candidate = auto_ml._best_candidate + return auto_ml_v2 + def fit( self, inputs: Optional[ diff --git a/tests/unit/sagemaker/automl/test_auto_ml_v2.py b/tests/unit/sagemaker/automl/test_auto_ml_v2.py index 94d87b0a8e..3b1bfa76ed 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml_v2.py +++ b/tests/unit/sagemaker/automl/test_auto_ml_v2.py @@ -24,6 +24,7 @@ CandidateEstimator, LocalAutoMLDataChannel, PipelineModel, + AutoML, ) from sagemaker.predictor import Predictor from sagemaker.session_settings import SessionSettings @@ -1100,3 +1101,29 @@ def without_user_input(sess): expected__with_user_input__with_default_bucket_only="s3://test", ) assert actual == expected + + +def test_automl_v1_to_automl_v2_mapping(): + auto_ml = AutoML( + role=ROLE, + target_attribute_name=TARGET_ATTRIBUTE_NAME, + sample_weight_attribute_name=SAMPLE_WEIGHT_ATTRIBUTE_NAME, + output_kms_key=OUTPUT_KMS_KEY, + output_path=OUTPUT_PATH, + max_candidates=MAX_CANDIDATES, + base_job_name=BASE_JOB_NAME, + ) + + auto_ml_v2 = AutoMLV2.from_auto_ml(auto_ml=auto_ml) + + assert isinstance(auto_ml_v2.problem_config, AutoMLTabularConfig) + assert auto_ml_v2.role == auto_ml.role + assert auto_ml_v2.problem_config.target_attribute_name == auto_ml.target_attribute_name + assert ( + auto_ml_v2.problem_config.sample_weight_attribute_name + == auto_ml.sample_weight_attribute_name + ) + assert auto_ml_v2.output_kms_key == auto_ml.output_kms_key + assert auto_ml_v2.output_path == auto_ml.output_path + assert auto_ml_v2.problem_config.max_candidates == auto_ml.max_candidate + assert auto_ml_v2.base_job_name == auto_ml.base_job_name From b9fbfbd81a39aeae431265ae1cf0fbbd253ffd2a Mon Sep 17 00:00:00 2001 From: Anton Repushko Date: Thu, 14 Mar 2024 22:38:55 +0100 Subject: [PATCH 15/20] Skip of tests which are long running and causing the ResourceLimitInUse exception (#4504) --- tests/integ/test_auto_ml_v2.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/integ/test_auto_ml_v2.py b/tests/integ/test_auto_ml_v2.py index 91802770d1..c42ff3432e 100644 --- a/tests/integ/test_auto_ml_v2.py +++ b/tests/integ/test_auto_ml_v2.py @@ -50,6 +50,10 @@ def test_time_series_forecasting_session_job_name(): return unique_name_from_base("ts-forecast-job", max_length=32) +@pytest.mark.skip( + reason="The test is disabled because it's causing the ResourceLimit exception. \ +Please run that manually before the proper fix." +) @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS, reason="AutoML is not supported in the region yet.", @@ -116,6 +120,10 @@ def test_auto_ml_v2_describe_auto_ml_job( assert desc["OutputDataConfig"] == expected_default_output_config +@pytest.mark.skip( + reason="The test is disabled because it's causing the ResourceLimit exception. \ +Please run that manually before the proper fix." +) @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS, reason="AutoML is not supported in the region yet.", @@ -181,6 +189,10 @@ def test_auto_ml_v2_attach(problem_type, job_name_fixture_key, sagemaker_session assert desc["OutputDataConfig"] == expected_default_output_config +@pytest.mark.skip( + reason="The test is disabled because it's causing the ResourceLimit exception. \ +Please run that manually before the proper fix." +) @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS, reason="AutoML is not supported in the region yet.", @@ -258,6 +270,10 @@ def test_list_candidates( pytest.skip("The job hasn't finished yet") +@pytest.mark.skip( + reason="The test is disabled because it's causing the ResourceLimit exception. \ +Please run that manually before the proper fix." +) @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS, reason="AutoML is not supported in the region yet.", @@ -329,6 +345,10 @@ def test_best_candidate( pytest.skip("The job hasn't finished yet") +@pytest.mark.skip( + reason="The test is disabled because it's causing the ResourceLimit exception. \ +Please run that manually before the proper fix." +) @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS or tests.integ.test_region() in tests.integ.NO_CANVAS_REGIONS, @@ -423,6 +443,10 @@ def test_deploy_best_candidate( pytest.skip("The job hasn't finished yet") +@pytest.mark.skip( + reason="The test is disabled because it's causing the ResourceLimit exception. \ +Please run that manually before the proper fix." +) @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS, reason="AutoML is not supported in the region yet.", From c2d5a2358c79f9b74b1b2e6de566d4a71c5ad593 Mon Sep 17 00:00:00 2001 From: Anton Repushko Date: Fri, 15 Mar 2024 16:49:29 +0100 Subject: [PATCH 16/20] Improvement of the tuner documentation (#4506) --- src/sagemaker/tuner.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 571f84761f..967bff1b99 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Placeholder docstring""" + from __future__ import absolute_import import importlib @@ -641,8 +642,11 @@ def __init__( extract the metric from the logs. This should be defined only for hyperparameter tuning jobs that don't use an Amazon algorithm. - strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations - (default: 'Bayesian'). + strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations. + More information about different strategies: + https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html. + Available options are: 'Bayesian', 'Random', 'Hyperband', + 'Grid' (default: 'Bayesian') objective_type (str or PipelineVariable): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize' (default: 'Maximize'). @@ -759,7 +763,8 @@ def __init__( self.autotune = autotune def override_resource_config( - self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]] + self, + instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]], ): """Override the instance configuration of the estimators used by the tuner. @@ -966,7 +971,7 @@ def fit( include_cls_metadata: Union[bool, Dict[str, bool]] = False, estimator_kwargs: Optional[Dict[str, dict]] = None, wait: bool = True, - **kwargs + **kwargs, ): """Start a hyperparameter tuning job. @@ -1055,7 +1060,7 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim allowed_keys=estimator_names, ) - for (estimator_name, estimator) in self.estimator_dict.items(): + for estimator_name, estimator in self.estimator_dict.items(): ins = inputs.get(estimator_name, None) if inputs is not None else None args = estimator_kwargs.get(estimator_name, {}) if estimator_kwargs is not None else {} self._prepare_estimator_for_tuning(estimator, ins, job_name, **args) @@ -1282,7 +1287,7 @@ def _attach_with_training_details_list(cls, sagemaker_session, estimator_cls, jo objective_metric_name_dict=objective_metric_name_dict, hyperparameter_ranges_dict=hyperparameter_ranges_dict, metric_definitions_dict=metric_definitions_dict, - **init_params + **init_params, ) def deploy( @@ -1297,7 +1302,7 @@ def deploy( model_name=None, kms_key=None, data_capture_config=None, - **kwargs + **kwargs, ): """Deploy the best trained or user specified model to an Amazon SageMaker endpoint. @@ -1363,7 +1368,7 @@ def deploy( model_name=model_name, kms_key=kms_key, data_capture_config=data_capture_config, - **kwargs + **kwargs, ) def stop_tuning_job(self): From 567800435af89bbd546cfe7ca1fdcb0911e7314f Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 15 Mar 2024 17:31:55 +0000 Subject: [PATCH 17/20] prepare release v2.213.0 --- CHANGELOG.md | 26 ++++++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70ae3538a4..f8d1e70116 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,31 @@ # Changelog +## v2.213.0 (2024-03-15) + +### Features + + * Add support for Streaming Inference + * tgi optimum 0.0.19, 0.0.20 releases + * support JumpStart proprietary models + * Add ModelDataSource and SourceUri support for model package and while registering + * Accept user-defined env variables for the entry-point + * Add overriding logic in ModelBuilder when task is provided + +### Bug Fixes and Other Changes + + * Improvement of the tuner documentation + * Skip of tests which are long running and causing the ResourceLimitInUse exception + * Add AutoML -> AutoMLV2 mapper + * add ci-health checks + * split coverage out from testenv in tox.ini + * add PT 2.2 support for smdistributed, pytorchddp, and torch_distributed distributions + * sagemaker session region not being used + * chore: emit warning when no instance specific gated training env var is available, and raise exception when accept_eula flag is not supplied + * enable github actions for PRs + * Move sagemaker pysdk version check after bootstrap in remote job + * make unit tests compatible with pytest-xdist + * Update tblib constraint + ## v2.212.0 (2024-03-06) ### Features diff --git a/VERSION b/VERSION index 4e29bf93f0..f654d86d26 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.212.1.dev0 +2.213.0 From 09fe1c665677d3e891c25efef07a224fdaa74ab2 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 15 Mar 2024 17:31:57 +0000 Subject: [PATCH 18/20] update development version to v2.213.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index f654d86d26..e88edc395f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.213.0 +2.213.1.dev0 From 5a7e99e3f3b3d2ab5847b18703505a747fa8bc03 Mon Sep 17 00:00:00 2001 From: ASHWIN KRISHNA <38850354+akrishna1995@users.noreply.github.com> Date: Fri, 15 Mar 2024 14:46:50 -0700 Subject: [PATCH 19/20] fix:urge customers to install latest version (#4507) --- README.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 827bfbfb36..18b236a5e8 100644 --- a/README.rst +++ b/README.rst @@ -67,11 +67,10 @@ Table of Contents Installing the SageMaker Python SDK ----------------------------------- -The SageMaker Python SDK is built to PyPI and can be installed with pip as follows: - +The SageMaker Python SDK is built to PyPI and the latest version of the SageMaker Python SDK can be installed with pip as follows :: - pip install sagemaker + pip install sagemaker== You can install from source by cloning this repository and running a pip install command in the root directory of the repository: From 434cba06a5f8367a7f2a2df800d6ae4ef14cedef Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:49:23 -0400 Subject: [PATCH 20/20] fix: list jumpstart models with invalid version strings (#4511) * fix: list jumpstart models with invalid versions * docstyle * docstyle * pylint * add more test * fix --- src/sagemaker/jumpstart/notebook_utils.py | 12 +++++++++++- tests/unit/sagemaker/jumpstart/constants.py | 14 ++++++++++++++ .../sagemaker/jumpstart/test_notebook_utils.py | 9 +++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 85a041379a..9df744531e 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -262,6 +262,15 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin return sorted(list(scripts)) +def _is_valid_version(version: str) -> bool: + """Checks if the version is convertable to Version class.""" + try: + Version(version) + return True + except Exception: # pylint: disable=broad-except + return False + + def list_jumpstart_models( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: Optional[str] = None, @@ -304,7 +313,8 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin ): if model_id not in model_id_version_dict: model_id_version_dict[model_id] = list() - model_id_version_dict[model_id].append(Version(version)) + model_version = Version(version) if _is_valid_version(version) else version + model_id_version_dict[model_id].append(model_version) if not list_versions: return sorted(list(model_id_version_dict.keys())) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index d7c4eb4921..1ea08724b9 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7577,6 +7577,20 @@ "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", "search_keywords": ["Text2Text", "Generation"], }, + { + "model_id": "ai21-paraphrase", + "version": "v1.00-rc2-not-valid-version", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "nc-soft-model-1", + "version": "v3.0-not-valid-version!", + "min_version": "2.0.0", + "spec_key": "proprietary-models/nc-soft-model-1/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, ] BASE_PROPRIETARY_SPEC = { diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 059cd7ccad..862d2b4174 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -25,6 +25,7 @@ list_jumpstart_models, list_jumpstart_scripts, list_jumpstart_tasks, + _is_valid_version, ) @@ -185,6 +186,13 @@ def test_list_jumpstart_frameworks( patched_get_model_specs.assert_not_called() +def test_is_valid_version(): + valid_version_strs = ["1.0", "1.0.0", "2012.4", "1!1.0", "1.dev0", "1.2.3+abc.dev1"] + invalid_version_strs = ["1.1.053_m", "invalid version", "v1-1.0-v2", "@"] + assert all(_is_valid_version(v) for v in valid_version_strs) + assert not any(_is_valid_version(v) for v in invalid_version_strs) + + class ListJumpStartModels(TestCase): @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -626,6 +634,7 @@ def test_list_jumpstart_proprietary_models( "ai21-paraphrase", "ai21-summarization", "lighton-mini-instruct40b", + "nc-soft-model-1", ] all_open_weight_model_ids = [