From 2b29f866706a91422f04ec692e71d1dbfda1eb9e Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 28 Feb 2024 19:54:42 +0000 Subject: [PATCH 01/42] prepare release v2.210.0 --- CHANGELOG.md | 15 +++++++++++++++ VERSION | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 86c0a3d426..bcc19b6b22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## v2.210.0 (2024-02-28) + +### Features + + * Prepend SageMaker Studio App Type to boto3 User Agent string + * TGI optimum 0.0.18 (general+llm) + * TGI 1.4.2 + +### Bug Fixes and Other Changes + + * tolerate vulnerable old model for integ test and temporarily skip test_list_jumpstart_models_script_filter + * add missing regions to pytorch config + * Add validation for sagemaker version on remote job + * fixed implementation of fail_on_violation for transform with monitoring + ## v2.209.0 (2024-02-24) ### Features diff --git a/VERSION b/VERSION index e19556b3c5..164470511a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.209.1.dev0 +2.210.0 From c6e93f91dfb17b545cf62de6ec5265f59baa6aba Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 28 Feb 2024 19:54:44 +0000 Subject: [PATCH 02/42] update development version to v2.210.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 164470511a..fdf99214f7 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.210.0 +2.210.1.dev0 From 2f1bed027f7c7f8fee5a33200fa626629bce77e0 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 28 Feb 2024 12:37:38 -0800 Subject: [PATCH 03/42] feat: Add new Triton DLC URIs (#4432) * Add new Triton DLC URIs * Update according to black and pylint --- .../sagemaker-tritonserver.json | 75 +++++++++++++++++++ src/sagemaker/image_uris.py | 6 ++ .../sagemaker/image_uris/expected_uris.py | 7 ++ .../image_uris/test_sagemaker_tritonserver.py | 55 ++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 src/sagemaker/image_uri_config/sagemaker-tritonserver.json create mode 100644 tests/unit/sagemaker/image_uris/test_sagemaker_tritonserver.py diff --git a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json new file mode 100644 index 0000000000..82397d913e --- /dev/null +++ b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json @@ -0,0 +1,75 @@ +{ + "processors": [ + "cpu", + "gpu" + ], + "scope": [ + "inference" + ], + "versions": { + "23.12": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "23.12-py3" + }, + "24.01": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": 
"871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "24.01-py3" + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 2b6870a11c..99692d0f8b 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -44,6 +44,7 @@ INFERENCE_GRAVITON = "inference_graviton" DATA_WRANGLER_FRAMEWORK = "data-wrangler" STABILITYAI_FRAMEWORK = "stabilityai" +SAGEMAKER_TRITONSERVER_FRAMEWORK = "sagemaker-tritonserver" @override_pipeline_parameter_var @@ -339,6 +340,11 @@ def _get_image_tag( if key in container_versions: tag = "-".join([tag, container_versions[key]]) + # Triton images don't have a trailing -gpu tag. Only -cpu images do. + if framework == SAGEMAKER_TRITONSERVER_FRAMEWORK: + if processor == "gpu": + tag = tag.rstrip("-gpu") + return tag diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index 438a00a038..094323ef0b 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -84,6 +84,13 @@ def djl_framework_uri(repo, account, tag, region=REGION): return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) +def sagemaker_triton_framework_uri(repo, account, tag, processor="gpu", region=REGION): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + if processor == "cpu": + tag = f"{tag}-cpu" + return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) + + def huggingface_llm_framework_uri( repo, account, diff --git a/tests/unit/sagemaker/image_uris/test_sagemaker_tritonserver.py b/tests/unit/sagemaker/image_uris/test_sagemaker_tritonserver.py new file mode 100644 index 0000000000..7dd75fc3b8 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/test_sagemaker_tritonserver.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import pytest +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +INSTANCE_TYPES = {"cpu": "ml.c4.xlarge", "gpu": "ml.p2.xlarge"} + + +@pytest.mark.parametrize( + "load_config_and_file_name", + ["sagemaker-tritonserver.json"], + indirect=True, +) +def test_sagemaker_tritonserver_uris(load_config_and_file_name): + config, file_name = load_config_and_file_name + framework = file_name.split(".json")[0] + VERSIONS = config["versions"] + processors = config["processors"] + for version in VERSIONS: + ACCOUNTS = config["versions"][version]["registries"] + tag = config["versions"][version]["tag_prefix"] + for processor in processors: + instance_type = INSTANCE_TYPES[processor] + for region in ACCOUNTS.keys(): + _test_sagemaker_tritonserver_uris( + ACCOUNTS[region], region, version, tag, framework, instance_type, processor + ) + + +def _test_sagemaker_tritonserver_uris( + account, region, version, tag, triton_framework, instance_type, processor +): + uri = image_uris.retrieve( + framework=triton_framework, region=region, version=version, instance_type=instance_type + ) + expected = expected_uris.sagemaker_triton_framework_uri( + "sagemaker-tritonserver", + account, + tag, + processor, + region, + ) + assert expected == uri From bb48c731b800e39c74817d3f26ad7427e8f3ebcd Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Thu, 29 Feb 2024 11:39:54 -0800 Subject: [PATCH 04/42] feat: Support selective pipeline execution between function step and regular step (#4392) --- src/sagemaker/workflow/function_step.py | 15 +-- src/sagemaker/workflow/functions.py | 8 +- src/sagemaker/workflow/pipeline.py | 27 +---- tests/integ/sagemaker/workflow/helpers.py | 4 +- .../workflow/test_selective_execution.py | 110 +++++++++++++++++- .../sagemaker/workflow/test_step_decorator.py | 28 ++--- .../sagemaker/workflow/test_condition_step.py | 6 +- .../sagemaker/workflow/test_function_step.py | 24 +--- .../unit/sagemaker/workflow/test_functions.py | 6 +- .../unit/sagemaker/workflow/test_pipeline.py | 17 --- 10 files changed, 137 insertions(+), 108 deletions(-) diff --git a/src/sagemaker/workflow/function_step.py b/src/sagemaker/workflow/function_step.py index 55e7eac90c..32353ece07 100644 --- a/src/sagemaker/workflow/function_step.py +++ b/src/sagemaker/workflow/function_step.py @@ -33,12 +33,11 @@ PipelineVariable, ) -from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.properties import Properties from sagemaker.workflow.retry import RetryPolicy from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum from sagemaker.workflow.step_collections import StepCollection -from sagemaker.workflow.step_outputs import StepOutput +from sagemaker.workflow.step_outputs import StepOutput, get_step from sagemaker.workflow.utilities import trim_request_dict, load_step_compilation_context from sagemaker.s3_utils import s3_path_join @@ -277,14 +276,12 @@ def _to_json_get(self) -> JsonGet: """Expression structure for workflow service calls using JsonGet resolution.""" from sagemaker.remote_function.core.stored_function import ( JSON_SERIALIZED_RESULT_KEY, - RESULTS_FOLDER, JSON_RESULTS_FILE, ) if not self._step.name: raise ValueError("Step name is not defined.") - s3_root_uri = self._step._job_settings.s3_root_uri # Resolve json path -- # Deserializer will be able to resolve a JsonGet using path "Return[1]" to # access value 10 from following serialized 
JSON: @@ -308,13 +305,9 @@ def _to_json_get(self) -> JsonGet: return JsonGet( s3_uri=Join( - "/", - [ - s3_root_uri, - ExecutionVariables.PIPELINE_NAME, - ExecutionVariables.PIPELINE_EXECUTION_ID, - self._step.name, - RESULTS_FOLDER, + on="/", + values=[ + get_step(self)._properties.OutputDataConfig.S3OutputPath, JSON_RESULTS_FILE, ], ), diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 4f63c4651b..947578d433 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -21,7 +21,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.execution_variables import ExecutionVariable from sagemaker.workflow.parameters import Parameter -from sagemaker.workflow.properties import PropertyFile +from sagemaker.workflow.properties import PropertyFile, Properties if TYPE_CHECKING: from sagemaker.workflow.steps import Step @@ -172,9 +172,9 @@ def _validate_json_get_s3_uri(self): for join_arg in s3_uri.values: if not is_pipeline_variable(join_arg): continue - if not isinstance(join_arg, (Parameter, ExecutionVariable)): + if not isinstance(join_arg, (Parameter, ExecutionVariable, Properties)): raise ValueError( f"Invalid JsonGet function {self.expr}. " - f"The Join values in JsonGet's s3_uri can only be a primitive object " - f"or Parameter or ExecutionVariable." + f"The Join values in JsonGet's s3_uri can only be a primitive object, " + f"Parameter, ExecutionVariable or Properties." ) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 6800f2a3ac..510ccd76bf 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -41,7 +41,6 @@ RESOURCE_NOT_FOUND_EXCEPTION, EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, ) -from sagemaker.workflow.function_step import DelayedReturn from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep from sagemaker.workflow.entities import ( Expression, @@ -725,10 +724,7 @@ def _interpolate( pipeline_name (str): The name of the pipeline to be interpolated. """ if isinstance(obj, (Expression, Parameter, Properties, StepOutput)): - updated_obj = _replace_pipeline_name_in_json_get_s3_uri( - obj=obj, pipeline_name=pipeline_name - ) - return updated_obj.expr + return obj.expr if isinstance(obj, CallbackOutput): step_name = callback_output_to_step_map[obj.output_name] @@ -760,27 +756,6 @@ def _interpolate( return new -# TODO: we should remove this once the ExecutionVariables.PIPELINE_NAME is fixed in backend -def _replace_pipeline_name_in_json_get_s3_uri(obj: Union[RequestType, Any], pipeline_name: str): - """Replace the ExecutionVariables.PIPELINE_NAME in DelayedReturn's JsonGet s3_uri - - with the pipeline_name, because ExecutionVariables.PIPELINE_NAME - is parsed as all lower-cased str in backend. - """ - if not isinstance(obj, DelayedReturn): - return obj - - json_get = obj._to_json_get() - - if not json_get.s3_uri: - return obj - # the s3 uri has to be a Join, which has been validated in JsonGet init - for i in range(len(json_get.s3_uri.values)): - if json_get.s3_uri.values[i] == ExecutionVariables.PIPELINE_NAME: - json_get.s3_uri.values[i] = pipeline_name - return json_get - - def _map_callback_outputs(steps: List[Step]): """Iterate over the provided steps, building a map of callback output parameters to step names. 
diff --git a/tests/integ/sagemaker/workflow/helpers.py b/tests/integ/sagemaker/workflow/helpers.py index 48e1e95734..20365ef169 100644 --- a/tests/integ/sagemaker/workflow/helpers.py +++ b/tests/integ/sagemaker/workflow/helpers.py @@ -33,7 +33,7 @@ def create_and_execute_pipeline( region_name, role, no_of_steps, - last_step_name, + last_step_name_prefix, execution_parameters, step_status, step_result_type=None, @@ -66,7 +66,7 @@ def create_and_execute_pipeline( len(execution_steps) == no_of_steps ), f"Expected {no_of_steps}, instead found {len(execution_steps)}" - assert last_step_name in execution_steps[0]["StepName"] + assert last_step_name_prefix in execution_steps[0]["StepName"] assert execution_steps[0]["StepStatus"] == step_status if step_result_type: result = execution.result(execution_steps[0]["StepName"]) diff --git a/tests/integ/sagemaker/workflow/test_selective_execution.py b/tests/integ/sagemaker/workflow/test_selective_execution.py index a2c0286c6a..a584c095d5 100644 --- a/tests/integ/sagemaker/workflow/test_selective_execution.py +++ b/tests/integ/sagemaker/workflow/test_selective_execution.py @@ -16,6 +16,7 @@ import pytest +from sagemaker.processing import ProcessingInput from tests.integ import DATA_DIR from sagemaker.sklearn import SKLearnProcessor from sagemaker.workflow.step_outputs import get_step @@ -84,7 +85,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -97,7 +98,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -115,7 +116,7 @@ def sum(a, b): pass -def test_selective_execution_of_regular_step_depended_by_function_step( +def test_selective_execution_of_regular_step_referenced_by_function_step( sagemaker_session, role, pipeline_name, @@ -168,7 +169,7 @@ def func_2(arg): region_name=region_name, role=role, no_of_steps=2, - last_step_name="func", + last_step_name_prefix="func", execution_parameters=dict(), step_status="Succeeded", step_result_type=str, @@ -182,7 +183,7 @@ def func_2(arg): region_name=region_name, role=role, no_of_steps=2, - last_step_name="func", + last_step_name_prefix="func", execution_parameters=dict(), step_status="Succeeded", step_result_type=str, @@ -199,3 +200,102 @@ def func_2(arg): pipeline.delete() except Exception: pass + + +def test_selective_execution_of_function_step_referenced_by_regular_step( + pipeline_session, + role, + pipeline_name, + region_name, + dummy_container_without_error, + sklearn_latest_version, +): + # Test Selective Pipeline Execution on function step -> [select: regular step] + os.environ["AWS_DEFAULT_REGION"] = region_name + processing_job_instance_counts = 2 + + @step( + name="step1", + role=role, + image_uri=dummy_container_without_error, + instance_type=INSTANCE_TYPE, + keep_alive_period_in_seconds=60, + ) + def func(var: int): + return 1, var + + step_output = func(processing_job_instance_counts) + + script_path = os.path.join(DATA_DIR, "dummy_script.py") + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + inputs = [ + ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"), + ] + + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_latest_version, + role=role, + instance_type=INSTANCE_TYPE, + instance_count=step_output[1], + command=["python3"], + 
sagemaker_session=pipeline_session, + base_job_name="test-sklearn", + ) + + step_args = sklearn_processor.run( + inputs=inputs, + code=script_path, + ) + process_step = ProcessingStep( + name="MyProcessStep", + step_args=step_args, + ) + + pipeline = Pipeline( + name=pipeline_name, + steps=[process_step], + sagemaker_session=pipeline_session, + ) + + try: + execution, _ = create_and_execute_pipeline( + pipeline=pipeline, + pipeline_name=pipeline_name, + region_name=region_name, + role=role, + no_of_steps=2, + last_step_name_prefix=process_step.name, + execution_parameters=dict(), + step_status="Succeeded", + wait_duration=1000, # seconds + ) + + _, execution_steps2 = create_and_execute_pipeline( + pipeline=pipeline, + pipeline_name=pipeline_name, + region_name=region_name, + role=role, + no_of_steps=2, + last_step_name_prefix=process_step.name, + execution_parameters=dict(), + step_status="Succeeded", + wait_duration=1000, # seconds + selective_execution_config=SelectiveExecutionConfig( + source_pipeline_execution_arn=execution.arn, + selected_steps=[process_step.name], + ), + ) + + execution_proc_job = pipeline_session.describe_processing_job( + execution_steps2[0]["Metadata"]["ProcessingJob"]["Arn"].split("/")[-1] + ) + assert ( + execution_proc_job["ProcessingResources"]["ClusterConfig"]["InstanceCount"] + == processing_job_instance_counts + ) + + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/integ/sagemaker/workflow/test_step_decorator.py b/tests/integ/sagemaker/workflow/test_step_decorator.py index bdd18a16f2..3c19a37cc3 100644 --- a/tests/integ/sagemaker/workflow/test_step_decorator.py +++ b/tests/integ/sagemaker/workflow/test_step_decorator.py @@ -159,7 +159,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -203,7 +203,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -252,7 +252,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -297,7 +297,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=1, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(TrainingInstanceCount="ml.m5.xlarge"), step_status="Succeeded", step_result_type=int, @@ -386,7 +386,7 @@ def func_2(*args): region_name=region_name, role=role, no_of_steps=3, - last_step_name="func", + last_step_name_prefix="func", execution_parameters=dict(param_a=3), step_status="Succeeded", step_result_type=tuple, @@ -438,7 +438,7 @@ def validate_file_exists(files_exists, files_does_not_exist): region_name=region_name, role=role, no_of_steps=1, - last_step_name="validate_file_exists", + last_step_name_prefix="validate_file_exists", execution_parameters=dict(), step_status="Succeeded", ) @@ -493,7 +493,7 @@ def train(x): region_name=region_name, role=role, no_of_steps=1, - last_step_name="train", + last_step_name_prefix="train", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -539,7 +539,7 @@ def cuberoot(x): region_name=region_name, role=role, no_of_steps=1, - last_step_name="cuberoot", + last_step_name_prefix="cuberoot", execution_parameters=dict(), 
step_status="Succeeded", step_result_type=numpy.float64, @@ -585,7 +585,7 @@ def divide(x, y): region_name=region_name, role=role, no_of_steps=1, - last_step_name="divide", + last_step_name_prefix="divide", execution_parameters=dict(), step_status="Failed", ) @@ -661,7 +661,7 @@ def func3(): region_name=region_name, role=role, no_of_steps=4, # The FailStep in else branch is not executed - last_step_name="MyConditionStep", + last_step_name_prefix="MyConditionStep", execution_parameters=dict(), step_status="Succeeded", ) @@ -733,7 +733,7 @@ def func(var: int): region_name=region_name, role=role, no_of_steps=2, - last_step_name=process_step.name, + last_step_name_prefix=process_step.name, execution_parameters=dict(), step_status="Succeeded", wait_duration=1000, # seconds @@ -846,7 +846,7 @@ def cuberoot(x): region_name=region_name, role=role, no_of_steps=1, - last_step_name="cuberoot", + last_step_name_prefix="cuberoot", execution_parameters=dict(), step_status="Succeeded", step_result_type=numpy.float64, @@ -890,7 +890,7 @@ def my_func(): region_name=region_name, role=role, no_of_steps=1, - last_step_name=get_step(step_a).name, + last_step_name_prefix=get_step(step_a).name, execution_parameters=dict(), step_status="Failed", ) @@ -950,7 +950,7 @@ def func_with_collision(var: str): region_name=region_name, role=role, no_of_steps=2, - last_step_name=get_step(step_output_b).name, + last_step_name_prefix=get_step(step_output_b).name, execution_parameters=dict(), step_status="Succeeded", step_result_type=str, diff --git a/tests/unit/sagemaker/workflow/test_condition_step.py b/tests/unit/sagemaker/workflow/test_condition_step.py index 315d549cce..019b5561ca 100644 --- a/tests/unit/sagemaker/workflow/test_condition_step.py +++ b/tests/unit/sagemaker/workflow/test_condition_step.py @@ -626,11 +626,7 @@ def _get_expected_jsonget_expr(step_name: str, path: str): "Std:Join": { "On": "/", "Values": [ - "s3://s3_bucket/test-prefix", - "MyPipeline", - {"Get": "Execution.PipelineExecutionId"}, - step_name, - "results", + {"Get": f"Steps.{step_name}.OutputDataConfig.S3OutputPath"}, "results.json", ], } diff --git a/tests/unit/sagemaker/workflow/test_function_step.py b/tests/unit/sagemaker/workflow/test_function_step.py index 888635ae02..25109fdc97 100644 --- a/tests/unit/sagemaker/workflow/test_function_step.py +++ b/tests/unit/sagemaker/workflow/test_function_step.py @@ -303,15 +303,7 @@ def func() -> type_hint: @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -@patch("sagemaker.remote_function.job._JobSettings") -def test_step_function_with_no_hint_on_return_values(mock_job_settings_ctr): - s3_root_uri = "s3://bucket" - mock_job_settings = Mock() - mock_job_settings.s3_root_uri = s3_root_uri - mock_job_settings.sagemaker_session = MOCKED_PIPELINE_CONFIG.sagemaker_session - - mock_job_settings_ctr.return_value = mock_job_settings - +def test_step_function_with_no_hint_on_return_values(): @step(name="step_name") def func(): return 1, 2, 3 @@ -330,11 +322,7 @@ def func(): "Std:Join": { "On": "/", "Values": [ - "s3://bucket", - {"Get": "Execution.PipelineName"}, - {"Get": "Execution.PipelineExecutionId"}, - "step_name", - "results", + {"Get": "Steps.step_name.OutputDataConfig.S3OutputPath"}, "results.json", ], } @@ -354,11 +342,7 @@ def func(): "Std:Join": { "On": "/", "Values": [ - "s3://bucket", - {"Get": "Execution.PipelineName"}, - {"Get": "Execution.PipelineExecutionId"}, - "step_name", - "results", + {"Get": "Steps.step_name.OutputDataConfig.S3OutputPath"}, "results.json", 
], } @@ -366,8 +350,6 @@ def func(): } } - mock_job_settings_ctr.assert_called_once() - with pytest.raises(NotImplementedError): for _ in step_output: pass diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 040899d883..61e1424bbf 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -279,8 +279,8 @@ def test_json_get_invalid_s3_uri_not_join(): def test_json_get_invalid_s3_uri_with_invalid_pipeline_variable(sagemaker_session): with pytest.raises(ValueError) as e: - JsonGet(s3_uri=Join(on="/", values=["s3:/", Properties(step_name="test")])) + JsonGet(s3_uri=Join(on="/", values=["s3:/", Join()])) assert ( - "The Join values in JsonGet's s3_uri can only be a primitive object or Parameter or ExecutionVariable." - in str(e.value) + "The Join values in JsonGet's s3_uri can only be a primitive object, " + "Parameter, ExecutionVariable or Properties." in str(e.value) ) diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index d658455d62..14c2d442eb 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -33,7 +33,6 @@ from sagemaker.workflow.pipeline import ( Pipeline, PipelineGraph, - _replace_pipeline_name_in_json_get_s3_uri, ) from sagemaker.workflow.pipeline_context import _PipelineConfig from sagemaker.workflow.pipeline_definition_config import PipelineDefinitionConfig @@ -1009,19 +1008,3 @@ def func(): (parameter, False), (delayed_return, True), ] - - -@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -@patch("sagemaker.remote_function.job._JobSettings", Mock()) -@pytest.mark.parametrize("obj, is_replaced", _generate_parameters_for_replace_pipeline_name_test()) -def test_replace_pipeline_name_in_json_get_s3_uri(obj, is_replaced): - updated_obj = _replace_pipeline_name_in_json_get_s3_uri( - obj=obj, - pipeline_name=_PIPELINE_NAME, - ) - if is_replaced: - assert updated_obj != obj - assert "Execution.PipelineName" not in str(updated_obj.expr) - assert _PIPELINE_NAME in str(updated_obj.expr) - else: - assert updated_obj == obj From c4d6c6525afe1b6a49f9cfc61fdc267cd9b12cec Mon Sep 17 00:00:00 2001 From: Anton Repushko Date: Fri, 1 Mar 2024 18:39:50 +0100 Subject: [PATCH 05/42] feat: Add AutoMLV2 support (#4461) * Add AutoMLV2 support * Improvements of the integration tests --------- Co-authored-by: Anton Repushko --- doc/api/training/automlv2.rst | 7 + doc/api/training/index.rst | 1 + src/sagemaker/__init__.py | 11 + src/sagemaker/automl/automlv2.py | 1433 +++++++++++++++++ src/sagemaker/config/__init__.py | 7 + src/sagemaker/config/config_schema.py | 38 + src/sagemaker/session.py | 173 +- tests/data/automl/data/CoLA.csv | 500 ++++++ .../automl/data/cifar10_subset/cat/0001.png | Bin 0 -> 2105 bytes .../automl/data/cifar10_subset/cat/0002.png | Bin 0 -> 2193 bytes .../automl/data/cifar10_subset/cat/0003.png | Bin 0 -> 2054 bytes .../automl/data/cifar10_subset/cat/0004.png | Bin 0 -> 2148 bytes .../automl/data/cifar10_subset/cat/0005.png | Bin 0 -> 2509 bytes .../automl/data/cifar10_subset/dog/0001.png | Bin 0 -> 2457 bytes .../automl/data/cifar10_subset/dog/0002.png | Bin 0 -> 2490 bytes .../automl/data/cifar10_subset/dog/0003.png | Bin 0 -> 2416 bytes .../automl/data/cifar10_subset/dog/0004.png | Bin 0 -> 2476 bytes .../automl/data/cifar10_subset/dog/0005.png | Bin 0 -> 2352 bytes .../automl/data/cifar10_subset/frog/0001.png 
| Bin 0 -> 2461 bytes .../automl/data/cifar10_subset/frog/0002.png | Bin 0 -> 2388 bytes .../automl/data/cifar10_subset/frog/0003.png | Bin 0 -> 2470 bytes .../automl/data/cifar10_subset/frog/0004.png | Bin 0 -> 2373 bytes .../automl/data/cifar10_subset/frog/0005.png | Bin 0 -> 2451 bytes tests/data/automl/data/customer_support.csv | 108 ++ tests/data/automl/data/sample_time_series.csv | 1336 +++++++++++++++ tests/integ/__init__.py | 1 + tests/integ/auto_ml_v2_utils.py | 129 ++ tests/integ/test_auto_ml_v2.py | 495 ++++++ tests/unit/__init__.py | 20 + .../unit/sagemaker/automl/test_auto_ml_v2.py | 1102 +++++++++++++ tests/unit/sagemaker/automl/test_config.py | 281 ++++ tests/unit/test_session.py | 256 +++ 32 files changed, 5896 insertions(+), 2 deletions(-) create mode 100644 doc/api/training/automlv2.rst create mode 100644 src/sagemaker/automl/automlv2.py create mode 100644 tests/data/automl/data/CoLA.csv create mode 100644 tests/data/automl/data/cifar10_subset/cat/0001.png create mode 100644 tests/data/automl/data/cifar10_subset/cat/0002.png create mode 100644 tests/data/automl/data/cifar10_subset/cat/0003.png create mode 100644 tests/data/automl/data/cifar10_subset/cat/0004.png create mode 100644 tests/data/automl/data/cifar10_subset/cat/0005.png create mode 100644 tests/data/automl/data/cifar10_subset/dog/0001.png create mode 100644 tests/data/automl/data/cifar10_subset/dog/0002.png create mode 100644 tests/data/automl/data/cifar10_subset/dog/0003.png create mode 100644 tests/data/automl/data/cifar10_subset/dog/0004.png create mode 100644 tests/data/automl/data/cifar10_subset/dog/0005.png create mode 100644 tests/data/automl/data/cifar10_subset/frog/0001.png create mode 100644 tests/data/automl/data/cifar10_subset/frog/0002.png create mode 100644 tests/data/automl/data/cifar10_subset/frog/0003.png create mode 100644 tests/data/automl/data/cifar10_subset/frog/0004.png create mode 100644 tests/data/automl/data/cifar10_subset/frog/0005.png create mode 100644 tests/data/automl/data/customer_support.csv create mode 100644 tests/data/automl/data/sample_time_series.csv create mode 100644 tests/integ/auto_ml_v2_utils.py create mode 100644 tests/integ/test_auto_ml_v2.py create mode 100644 tests/unit/sagemaker/automl/test_auto_ml_v2.py create mode 100644 tests/unit/sagemaker/automl/test_config.py diff --git a/doc/api/training/automlv2.rst b/doc/api/training/automlv2.rst new file mode 100644 index 0000000000..212918a4f4 --- /dev/null +++ b/doc/api/training/automlv2.rst @@ -0,0 +1,7 @@ +AutoMLV2 +-------- + +.. 
automodule:: sagemaker.automl.automlv2 + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/api/training/index.rst b/doc/api/training/index.rst index e59491217d..5f85359d20 100644 --- a/doc/api/training/index.rst +++ b/doc/api/training/index.rst @@ -8,6 +8,7 @@ Training APIs algorithm analytics automl + automlv2 debugger estimators tuner diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 6f021c8657..a1769b5a4c 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -61,6 +61,17 @@ from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401 from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401 +from sagemaker.automl.automlv2 import ( # noqa: F401 + AutoMLV2, + AutoMLJobV2, + LocalAutoMLDataChannel, + AutoMLDataChannel, + AutoMLTimeSeriesForecastingConfig, + AutoMLImageClassificationConfig, + AutoMLTabularConfig, + AutoMLTextClassificationConfig, + AutoMLTextGenerationConfig, +) from sagemaker.debugger import ProfilerConfig, Profiler # noqa: F401 diff --git a/src/sagemaker/automl/automlv2.py b/src/sagemaker/automl/automlv2.py new file mode 100644 index 0000000000..c855414f0b --- /dev/null +++ b/src/sagemaker/automl/automlv2.py @@ -0,0 +1,1433 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""A class for SageMaker AutoML V2 Jobs.""" +from __future__ import absolute_import + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +from sagemaker import Model, PipelineModel, s3 +from sagemaker.automl.candidate_estimator import CandidateEstimator +from sagemaker.config import ( + AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, + AUTO_ML_KMS_KEY_ID_PATH, + AUTO_ML_ROLE_ARN_PATH, + AUTO_ML_VOLUME_KMS_KEY_ID_PATH, + AUTO_ML_VPC_CONFIG_PATH, +) +from sagemaker.job import _Job +from sagemaker.session import Session +from sagemaker.utils import Tags, format_tags, name_from_base, resolve_value_from_config + +logger = logging.getLogger("sagemaker") + + +@dataclass +class AutoMLTabularConfig(object): + """Configuration of a tabular problem. + + Args: + target_attribute_name (str): The name of the column in the tabular dataset + that contains the values to be predicted. + algorithms_config (list(str)): The selection of algorithms run on a dataset to train + the model candidates of an Autopilot job. + feature_specification_s3_uri (str): A URL to the Amazon S3 data source containing + selected features and specified data types from the input data source of an AutoML job. + generate_candidate_definitions_only (bool): Whether to generates + possible candidates without training the models. + mode (str): The method that AutoML job uses to train the model. + Valid values: AUTO or ENSEMBLING or HYPERPARAMETER_TUNING. + problem_type (str): Defines the type of supervised learning + available for the candidates. Available problem types are: + `BinaryClassification`, `MulticlassClassification` and `Regression`. 
+ sample_weight_attribute_name (str): The name of dataset column representing + sample weights. + max_candidates (int): The maximum number of training jobs allowed to run. + max_runtime_per_training_job_in_seconds (int): The maximum time, in seconds, + that each training job executed inside hyperparameter tuning + is allowed to run as part of a hyperparameter tuning job. + max_total_job_runtime_in_seconds (int): The total wait time of an AutoML job. + """ + + target_attribute_name: str + algorithms_config: Optional[List[str]] = None + feature_specification_s3_uri: Optional[str] = None + generate_candidate_definitions_only: Optional[bool] = None + mode: Optional[str] = None + problem_type: Optional[str] = None + sample_weight_attribute_name: Optional[str] = None + max_candidates: Optional[int] = None + max_runtime_per_training_job_in_seconds: Optional[int] = None + max_total_job_runtime_in_seconds: Optional[int] = None + + @classmethod + def from_response_dict(cls, api_problem_type_config: dict): + """Convert the API response to the native object.""" + completion_criteria = api_problem_type_config.get("CompletionCriteria", {}) + return cls( + max_candidates=completion_criteria.get("MaxCandidates"), + max_runtime_per_training_job_in_seconds=completion_criteria.get( + "MaxRuntimePerTrainingJobInSeconds" + ), + max_total_job_runtime_in_seconds=completion_criteria.get( + "MaxAutoMLJobRuntimeInSeconds" + ), + algorithms_config=api_problem_type_config.get("CandidateGenerationConfig", {}) + .get("AlgorithmsConfig", [{}])[0] + .get("AutoMLAlgorithms", None), + feature_specification_s3_uri=api_problem_type_config.get("FeatureSpecificationS3Uri"), + mode=api_problem_type_config.get("Mode"), + generate_candidate_definitions_only=api_problem_type_config.get( + "GenerateCandidateDefinitionsOnly", None + ), + problem_type=api_problem_type_config.get("ProblemType"), + target_attribute_name=api_problem_type_config.get("TargetAttributeName"), + sample_weight_attribute_name=api_problem_type_config.get("SampleWeightAttributeName"), + ) + + def to_request_dict(self): + """Convert the native object to the API request format.""" + config = {} + if _is_completion_criteria_exists_in_config( + max_candidates=self.max_candidates, + max_runtime_per_training_job_in_seconds=self.max_runtime_per_training_job_in_seconds, + max_total_job_runtime_in_seconds=self.max_total_job_runtime_in_seconds, + ): + config["CompletionCriteria"] = _completion_criteria_to_request_dict( + self.max_candidates, + self.max_runtime_per_training_job_in_seconds, + self.max_total_job_runtime_in_seconds, + ) + config["TargetAttributeName"] = self.target_attribute_name + if self.problem_type is not None: + config["ProblemType"] = self.problem_type + if self.sample_weight_attribute_name is not None: + config["SampleWeightAttributeName"] = self.sample_weight_attribute_name + if self.mode is not None: + config["Mode"] = self.mode + if self.generate_candidate_definitions_only is not None: + config["GenerateCandidateDefinitionsOnly"] = self.generate_candidate_definitions_only + if self.feature_specification_s3_uri is not None: + config["FeatureSpecificationS3Uri"] = self.feature_specification_s3_uri + + if self.algorithms_config is not None: + config["CandidateGenerationConfig"] = { + "AlgorithmsConfig": [{"AutoMLAlgorithms": self.algorithms_config}] + } + return {"TabularJobConfig": config} + + +@dataclass +class AutoMLImageClassificationConfig(object): + """Configuration of an image classification problem. 
+ + Args: + max_candidates (int): The maximum number of training jobs allowed to run. + max_runtime_per_training_job_in_seconds (int): The maximum time, in seconds, + that each training job executed inside hyperparameter tuning + is allowed to run as part of a hyperparameter tuning job. + max_total_job_runtime_in_seconds (int): The total wait time of an AutoML job. + """ + + max_candidates: Optional[int] = None + max_runtime_per_training_job_in_seconds: Optional[int] = None + max_total_job_runtime_in_seconds: Optional[int] = None + + @classmethod + def from_response_dict(cls, api_problem_type_config: dict): + """Convert the API response to the native object.""" + completion_criteria = api_problem_type_config.get("CompletionCriteria", {}) + return cls( + max_candidates=completion_criteria.get("MaxCandidates"), + max_runtime_per_training_job_in_seconds=completion_criteria.get( + "MaxRuntimePerTrainingJobInSeconds" + ), + max_total_job_runtime_in_seconds=completion_criteria.get( + "MaxAutoMLJobRuntimeInSeconds" + ), + ) + + def to_request_dict(self): + """Convert the native object to the API request format.""" + config = {} + if _is_completion_criteria_exists_in_config( + max_candidates=self.max_candidates, + max_runtime_per_training_job_in_seconds=self.max_runtime_per_training_job_in_seconds, + max_total_job_runtime_in_seconds=self.max_total_job_runtime_in_seconds, + ): + config["CompletionCriteria"] = _completion_criteria_to_request_dict( + self.max_candidates, + self.max_runtime_per_training_job_in_seconds, + self.max_total_job_runtime_in_seconds, + ) + return {"ImageClassificationJobConfig": config} + + +@dataclass +class AutoMLTextClassificationConfig(object): + """Configuration of a text classification problem. + + Args: + content_column (str): The name of the column used to provide the text to be classified. + It should not be the same as the target label column. + target_label_column (str): The name of the column used to provide the class labels. + It should not be same as the content column. + max_candidates (int): The maximum number of training jobs allowed to run. + max_runtime_per_training_job_in_seconds (int): The maximum time, in seconds, + that each training job executed inside hyperparameter tuning + is allowed to run as part of a hyperparameter tuning job. + max_total_job_runtime_in_seconds (int): The total wait time of an AutoML job. 
+ """ + + content_column: str + target_label_column: str + max_candidates: Optional[int] = None + max_runtime_per_training_job_in_seconds: Optional[int] = None + max_total_job_runtime_in_seconds: Optional[int] = None + + @classmethod + def from_response_dict(cls, api_problem_type_config: dict): + """Convert the API response to the native object.""" + completion_criteria = api_problem_type_config.get("CompletionCriteria", {}) + return cls( + max_candidates=completion_criteria.get("MaxCandidates"), + max_runtime_per_training_job_in_seconds=completion_criteria.get( + "MaxRuntimePerTrainingJobInSeconds" + ), + max_total_job_runtime_in_seconds=completion_criteria.get( + "MaxAutoMLJobRuntimeInSeconds" + ), + content_column=api_problem_type_config["ContentColumn"], + target_label_column=api_problem_type_config["TargetLabelColumn"], + ) + + def to_request_dict(self): + """Convert the native object to the API request format.""" + config = {} + if _is_completion_criteria_exists_in_config( + max_candidates=self.max_candidates, + max_runtime_per_training_job_in_seconds=self.max_runtime_per_training_job_in_seconds, + max_total_job_runtime_in_seconds=self.max_total_job_runtime_in_seconds, + ): + config["CompletionCriteria"] = _completion_criteria_to_request_dict( + self.max_candidates, + self.max_runtime_per_training_job_in_seconds, + self.max_total_job_runtime_in_seconds, + ) + + config["ContentColumn"] = self.content_column + config["TargetLabelColumn"] = self.target_label_column + + return {"TextClassificationJobConfig": config} + + +@dataclass +class AutoMLTextGenerationConfig(object): + """Configuration of a text generation problem. + + Args: + base_model_name (str): The name of the base model to fine-tune. + Autopilot supports fine-tuning a variety of large language models. + For information on the list of supported models, see Text generation models supporting + fine-tuning in Autopilot: + https://docs.aws.amazon.com/sagemaker/latest/dg/autopilot-llms-finetuning-models.html#autopilot-llms-finetuning-supported-llms. + If no BaseModelName is provided, the default model used is Falcon7BInstruct. + accept_eula (bool): Specifies agreement to the model end-user license agreement (EULA). + The AcceptEula value must be explicitly defined as True + in order to accept the EULA that this model requires. + For example, LLAMA2 requires to accept EULA. You are responsible for reviewing + and complying with any applicable license terms and making sure they are acceptable + for your use case before downloading or using a model. + text_generation_hyper_params (dict): The hyperparameters used to configure and optimize + the learning process of the base model. You can set any combination of the following + hyperparameters for all base models. Supported parameters are: + + - epochCount: The number of times the model goes through the entire training dataset. + - batchSize: The number of data samples used in each iteration of training. + - learningRate: The step size at which a model's parameters are updated during training. + - learningRateWarmupSteps: The number of training steps during which the learning rate + gradually increases before reaching its target or maximum value. + + max_candidates (int): The maximum number of training jobs allowed to run. + max_runtime_per_training_job_in_seconds (int): The maximum time, in seconds, + that each training job executed inside hyperparameter tuning + is allowed to run as part of a hyperparameter tuning job. 
+ max_total_job_runtime_in_seconds (int): The total wait time of an AutoML job. + """ + + base_model_name: Optional[str] = None + accept_eula: Optional[bool] = None + text_generation_hyper_params: Optional[Dict[str, str]] = None + max_candidates: Optional[int] = None + max_runtime_per_training_job_in_seconds: Optional[int] = None + max_total_job_runtime_in_seconds: Optional[int] = None + + @classmethod + def from_response_dict(cls, api_problem_type_config: dict): + """Convert the API response to the native object.""" + completion_criteria = api_problem_type_config.get("CompletionCriteria", {}) + return cls( + max_candidates=completion_criteria.get("MaxCandidates"), + max_runtime_per_training_job_in_seconds=completion_criteria.get( + "MaxRuntimePerTrainingJobInSeconds" + ), + max_total_job_runtime_in_seconds=completion_criteria.get( + "MaxAutoMLJobRuntimeInSeconds" + ), + base_model_name=api_problem_type_config.get("BaseModelName"), + text_generation_hyper_params=api_problem_type_config.get( + "TextGenerationHyperParameters" + ), + accept_eula=api_problem_type_config.get("ModelAccessConfig", {}).get( + "AcceptEula", None + ), + ) + + def to_request_dict(self): + """Convert the native object to the API request format.""" + config = {} + if _is_completion_criteria_exists_in_config( + max_candidates=self.max_candidates, + max_runtime_per_training_job_in_seconds=self.max_runtime_per_training_job_in_seconds, + max_total_job_runtime_in_seconds=self.max_total_job_runtime_in_seconds, + ): + config["CompletionCriteria"] = {} + if self.max_candidates is not None: + config["CompletionCriteria"]["MaxCandidates"] = self.max_candidates + if self.max_runtime_per_training_job_in_seconds is not None: + config["CompletionCriteria"][ + "MaxRuntimePerTrainingJobInSeconds" + ] = self.max_runtime_per_training_job_in_seconds + if self.max_total_job_runtime_in_seconds is not None: + config["CompletionCriteria"][ + "MaxAutoMLJobRuntimeInSeconds" + ] = self.max_total_job_runtime_in_seconds + + if self.base_model_name is not None: + config["BaseModelName"] = self.base_model_name + if self.accept_eula is not None: + config["ModelAccessConfig"] = {"AcceptEula": self.accept_eula} + if self.text_generation_hyper_params is not None: + config["TextGenerationHyperParameters"] = self.text_generation_hyper_params + + return {"TextGenerationJobConfig": config} + + +@dataclass +class AutoMLTimeSeriesForecastingConfig(object): + """Configuration of a time series forecasting problem. + + Args: + forecast_frequency (str): The frequency of predictions in a forecast. + Valid intervals are an integer followed by Y (Year), + M (Month), W (Week), D (Day), H (Hour), and min (Minute). + For example, 1D indicates every day and 15min indicates every 15 minutes. + The value of a frequency must not overlap with the next larger frequency. + For example, you must use a frequency of 1H instead of 60min. + forecast_horizon (int): The number of time-steps that the model predicts. The forecast + horizon is also called the prediction length. The maximum forecast horizon + is the lesser of 500 time-steps or 1/4 of the time-steps in the dataset. + item_identifier_attribute_name (str): The name of the column that represents + the set of item identifiers for which you want to predict the target value. + target_attribute_name (str): The name of the column representing the target variable + that you want to predict for each item in your dataset. + The data type of the target variable must be numerical. 
+ timestamp_attribute_name (str): The name of the column indicating a point in time at which + the target value of a given item is recorded. + grouping_attribute_names (list(str)): A set of columns names that can be grouped with the + item identifier column to create a composite key for which a target value is predicted. + feature_specification_s3_uri (str): A URL to the Amazon S3 data source containing + selected features and specified data types from the input data source of an AutoML job. + forecast_quantiles (list(str)): The quantiles used to train the model for forecasts + at a specified quantile. You can specify quantiles from 0.01 (p1) to 0.99 (p99), + by increments of 0.01 or higher. Up to five forecast quantiles can be specified. + When ForecastQuantiles is not provided, the AutoML job uses the quantiles p10, p50, + and p90 as default. + holiday_config (list(str)): The country code for the holiday calendar. + For the list of public holiday calendars supported by AutoML job V2, see Country Codes: + https://docs.aws.amazon.com/sagemaker/latest/dg/autopilot-timeseries-forecasting-holiday-calendars.html#holiday-country-codes. + Use the country code corresponding to the country of your choice. + aggregation (dict): A key value pair defining the aggregation method for a column, + where the key is the column name and the value is the aggregation method. + Aggregation is only supported for the target column. The supported aggregation methods + are sum (default), avg, first, min, max. + filling (dict): A key value pair defining the filling method for a column, + where the key is the column name and the value is an object which defines + the filling logic. You can specify multiple filling methods for a single column. + The supported filling methods and their corresponding options are: + + - frontfill: none (Supported only for target column) + - middlefill: zero, value, median, mean, min, max + - backfill: zero, value, median, mean, min, max + - futurefill: zero, value, median, mean, min, max + + To set a filling method to a specific value, set the fill parameter to + the chosen filling method value (for example "backfill" : "value"), + and define the filling value in an additional parameter prefixed with "_value". + For example, to set backfill to a value of 2, you must include two parameters: + "backfill": "value" and "backfill_value":"2". + max_candidates (int): The maximum number of training jobs allowed to run. + max_runtime_per_training_job_in_seconds (int): The maximum time, in seconds, + that each training job executed inside hyperparameter tuning + is allowed to run as part of a hyperparameter tuning job. + max_total_job_runtime_in_seconds (int): The total wait time of an AutoML job. 
+ """ + + forecast_frequency: str + forecast_horizon: int + item_identifier_attribute_name: str + target_attribute_name: str + timestamp_attribute_name: str + grouping_attribute_names: Optional[List[str]] = None + feature_specification_s3_uri: Optional[str] = None + forecast_quantiles: Optional[List[str]] = None + holiday_config: Optional[List[str]] = None + aggregation: Optional[Dict[str, str]] = None + filling: Optional[Dict[str, str]] = None + max_candidates: Optional[int] = None + max_runtime_per_training_job_in_seconds: Optional[int] = None + max_total_job_runtime_in_seconds: Optional[int] = None + + @classmethod + def from_response_dict(cls, api_problem_type_config: dict): + """Convert the API response to the native object.""" + completion_criteria = api_problem_type_config.get("CompletionCriteria", {}) + return cls( + max_candidates=completion_criteria.get("MaxCandidates"), + max_runtime_per_training_job_in_seconds=completion_criteria.get( + "MaxRuntimePerTrainingJobInSeconds" + ), + max_total_job_runtime_in_seconds=completion_criteria.get( + "MaxAutoMLJobRuntimeInSeconds" + ), + feature_specification_s3_uri=api_problem_type_config.get("FeatureSpecificationS3Uri"), + forecast_frequency=api_problem_type_config["ForecastFrequency"], + forecast_horizon=api_problem_type_config["ForecastHorizon"], + item_identifier_attribute_name=api_problem_type_config["TimeSeriesConfig"][ + "ItemIdentifierAttributeName" + ], + target_attribute_name=api_problem_type_config["TimeSeriesConfig"][ + "TargetAttributeName" + ], + timestamp_attribute_name=api_problem_type_config["TimeSeriesConfig"][ + "TimestampAttributeName" + ], + forecast_quantiles=api_problem_type_config.get("ForecastQuantiles"), + aggregation=api_problem_type_config.get("Transformations", {}).get("Aggregation"), + filling=api_problem_type_config.get("Transformations", {}).get("Filling"), + grouping_attribute_names=api_problem_type_config.get("TimeSeriesConfig", {}).get( + "GroupingAttributeNames" + ), + holiday_config=api_problem_type_config.get("HolidayConfig", [{}])[0].get("CountryCode"), + ) + + def to_request_dict(self): + """Convert the native object to the API request format.""" + config = {} + if _is_completion_criteria_exists_in_config( + max_candidates=self.max_candidates, + max_runtime_per_training_job_in_seconds=self.max_runtime_per_training_job_in_seconds, + max_total_job_runtime_in_seconds=self.max_total_job_runtime_in_seconds, + ): + config["CompletionCriteria"] = _completion_criteria_to_request_dict( + self.max_candidates, + self.max_runtime_per_training_job_in_seconds, + self.max_total_job_runtime_in_seconds, + ) + + if self.feature_specification_s3_uri is not None: + config["FeatureSpecificationS3Uri"] = self.feature_specification_s3_uri + + config["ForecastHorizon"] = self.forecast_horizon + config["ForecastFrequency"] = self.forecast_frequency + config["TimeSeriesConfig"] = { + "TargetAttributeName": self.target_attribute_name, + "TimestampAttributeName": self.timestamp_attribute_name, + "ItemIdentifierAttributeName": self.item_identifier_attribute_name, + } + if self.grouping_attribute_names: + config["TimeSeriesConfig"]["GroupingAttributeNames"] = self.grouping_attribute_names + + if self.forecast_quantiles: + config["ForecastQuantiles"] = self.forecast_quantiles + + if self.holiday_config: + config["HolidayConfig"] = [] + config["HolidayConfig"].append({"CountryCode": self.holiday_config}) + + if self.aggregation or self.filling: + config["Transformations"] = {} + if self.aggregation: + 
config["Transformations"]["Aggregation"] = self.aggregation + if self.filling: + config["Transformations"]["Filling"] = self.filling + + return {"TimeSeriesForecastingJobConfig": config} + + +@dataclass +class AutoMLDataChannel(object): + """Class to represnt the datasource which will be used for mode training. + + Args: + s3_data_type (str): The data type for S3 data source. Valid values: ManifestFile, + AugmentedManifestFile or S3Prefix. + s3_uri (str): The URL to the Amazon S3 data source. The Uri refers to the Amazon S3 prefix + or ManifestFile depending on the data type. + channel_type (str): The type of channel. Valid values: `training` or `validation`. + Defines whether the data are used for training or validation. + The default value is training. + Channels for training and validation must share the same content_type. + compression_type (str): The compression type for input data. Gzip or None. + content_type (str): The content type of the data from the input source. + """ + + s3_data_type: str + s3_uri: str + channel_type: Optional[str] = None + compression_type: Optional[str] = None + content_type: Optional[str] = None + + @classmethod + def from_response_dict(cls, data_channel: dict): + """Convert the API response to the native object.""" + return cls( + s3_data_type=data_channel["DataSource"]["S3DataSource"]["S3DataType"], + s3_uri=data_channel["DataSource"]["S3DataSource"]["S3Uri"], + channel_type=data_channel.get("ChannelType"), + compression_type=data_channel.get("CompressionType"), + content_type=data_channel.get("ContentType"), + ) + + def to_request_dict(self): + """Convert the native object to the API request format.""" + request_dict = { + "DataSource": { + "S3DataSource": { + "S3DataType": self.s3_data_type, + "S3Uri": self.s3_uri, + } + }, + } + if self.channel_type: + request_dict["ChannelType"] = self.channel_type + if self.compression_type: + request_dict["CompressionType"] = self.compression_type + if self.content_type: + request_dict["ContentType"] = self.content_type + return request_dict + + +@dataclass +class LocalAutoMLDataChannel(object): + """Class to represnt a local datasource which will be uploaded to S3. + + Args: + data_type (str): The data type for S3 data source. Valid values: ManifestFile, + AugmentedManifestFile or S3Prefix. + path (str): The path to the local data which will be uploaded to S3. + channel_type (str): The type of channel. Valid values: `training` or `validation`. + Defines whether the data are used for training or validation. + The default value is training. + Channels for training and validation must share the same content_type. + compression_type (str): The compression type for input data. Gzip or None. + content_type (str): The content type of the data from the input source. 
+ """ + + data_type: str + path: str + channel_type: Optional[str] = None + compression_type: Optional[str] = None + content_type: Optional[str] = None + + +def _upload_local_dataset( + local_dataset: LocalAutoMLDataChannel, sagemaker_session: Session +) -> AutoMLDataChannel: + """Method to upload a local dataset to the S3 and convert it to an AutoMLDataChannel object.""" + s3_path = sagemaker_session.upload_data(local_dataset.path, key_prefix="auto-ml-v2-input-data") + return AutoMLDataChannel( + s3_uri=s3_path, + s3_data_type=local_dataset.data_type, + channel_type=local_dataset.channel_type, + compression_type=local_dataset.compression_type, + content_type=local_dataset.content_type, + ) + + +def _is_completion_criteria_exists_in_config( + max_candidates: int = None, + max_runtime_per_training_job_in_seconds: int = None, + max_total_job_runtime_in_seconds: int = None, +) -> bool: + """Check is the completion criteria was provided as part of the problem config or not.""" + return ( + max_candidates is not None + or max_runtime_per_training_job_in_seconds is not None + or max_total_job_runtime_in_seconds is not None + ) + + +def _completion_criteria_to_request_dict( + max_candidates: int = None, + max_runtime_per_training_job_in_seconds: int = None, + max_total_job_runtime_in_seconds: int = None, +): + """Convert a completion criteria object to an API request format.""" + config = {} + if max_candidates is not None: + config["MaxCandidates"] = max_candidates + if max_runtime_per_training_job_in_seconds is not None: + config["MaxRuntimePerTrainingJobInSeconds"] = max_runtime_per_training_job_in_seconds + if max_total_job_runtime_in_seconds is not None: + config["MaxAutoMLJobRuntimeInSeconds"] = max_total_job_runtime_in_seconds + return config + + +class AutoMLV2(object): + """A class for creating and interacting with SageMaker AutoMLV2 jobs.""" + + def __init__( + self, + problem_config: Union[ + AutoMLTabularConfig, + AutoMLImageClassificationConfig, + AutoMLTextClassificationConfig, + AutoMLTextGenerationConfig, + AutoMLTimeSeriesForecastingConfig, + ], + base_job_name: Optional[str] = None, + output_path: Optional[str] = None, + job_objective: Optional[Dict[str, str]] = None, + validation_fraction: Optional[float] = None, + auto_generate_endpoint_name: Optional[bool] = None, + endpoint_name: Optional[str] = None, + output_kms_key: Optional[str] = None, + role: Optional[str] = None, + volume_kms_key: Optional[str] = None, + encrypt_inter_container_traffic: Optional[bool] = None, + vpc_config: Optional[Dict[str, List]] = None, + tags: Optional[Tags] = None, + sagemaker_session: Optional[Session] = None, + ): + """Initialize an AutoMLV2 object. + + Args: + problem_config (object): A collection of settings specific + to the problem type used to configure an AutoML job V2. + There must be one and only one config of the following type. + Supported problem types are: + + - Image Classification (sagemaker.automl.automlv2.ImageClassificationJobConfig), + - Tabular (sagemaker.automl.automlv2.TabularJobConfig), + - Text Classification (sagemaker.automl.automlv2.TextClassificationJobConfig), + - Text Generation (TextGenerationJobConfig), + - Time Series Forecasting ( + sagemaker.automl.automlv2.TimeSeriesForecastingJobConfig). + + base_job_name (str): The name of AutoML job. + The name must be unique to within the AWS account and is case-insensitive. + output_path (str): The Amazon S3 output path. Must be 128 characters or less. 
+ job_objective (dict[str, str]): Defines the objective metric + used to measure the predictive quality of an AutoML job. + In the format of: {"MetricName": str}. Available metrics are listed here: + https://docs.aws.amazon.com/sagemaker/latest/dg/autopilot-metrics-validation.html + validation_fraction (float): A float that specifies the portion of + the input dataset to be used for validation. The value should be in (0, 1) range. + auto_generate_endpoint_name (bool): Whether to automatically generate + an endpoint name for a one-click Autopilot model deployment. + If set auto_generate_endpoint_name to True, do not specify the endpoint_name. + endpoint_name (str): Specifies the endpoint name to use for a one-click AutoML + model deployment if the endpoint name is not generated automatically. + Specify the endpoint_name if and only if + auto_generate_endpoint_name is set to False + output_kms_key (str): The AWS KMS encryption key ID for output data configuration + role (str): The ARN of the role that is used to create the job and access the data. + volume_kms_key (str): The key used to encrypt stored data. + encrypt_inter_container_traffic (bool): whether to use traffic encryption + between the container layers. + vpc_config (dict): Specifies a VPC that your training jobs and hosted models have + access to. Contents include "SecurityGroupIds" and "Subnets". + tags (Optional[Tags]): Tags to attach to this specific endpoint. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. + + Returns: + AutoMLV2 object. + """ + self.base_job_name = base_job_name + self.problem_config = problem_config + self.job_objective = job_objective + self.validation_fraction = validation_fraction + self.auto_generate_endpoint_name = auto_generate_endpoint_name + self.endpoint_name = endpoint_name + self.output_path = output_path + self.sagemaker_session = sagemaker_session or Session() + + self.vpc_config = resolve_value_from_config( + vpc_config, + AUTO_ML_VPC_CONFIG_PATH, + sagemaker_session=self.sagemaker_session, + ) + self.volume_kms_key = resolve_value_from_config( + volume_kms_key, + AUTO_ML_VOLUME_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, + ) + self.output_kms_key = resolve_value_from_config( + output_kms_key, + AUTO_ML_KMS_KEY_ID_PATH, + sagemaker_session=self.sagemaker_session, + ) + self.role = resolve_value_from_config( + role, AUTO_ML_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session + ) + if not self.role: + # Originally IAM role was a required parameter. + # Now we marked that as Optional because we can fetch it from SageMakerConfig + # Because of marking that parameter as optional, we should validate if it is None, even + # after fetching the config. 
+ raise ValueError("An AWS IAM role is required to create an AutoML job.") + + if isinstance(problem_config, AutoMLTabularConfig): + self._check_problem_type_and_job_objective(problem_config.problem_type, job_objective) + + self.encrypt_inter_container_traffic = resolve_value_from_config( + direct_input=encrypt_inter_container_traffic, + config_path=AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, + default_value=False, + sagemaker_session=self.sagemaker_session, + ) + + if self.output_path is None: + self.output_path = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + with_end_slash=True, + ) + + self.tags = format_tags(tags) + self.sagemaker_session = sagemaker_session or Session() + + self.current_job_name = None + self.inputs = None + self.latest_auto_ml_job = None + self._auto_ml_job_desc = None + self._best_candidate = None + + def fit( + self, + inputs: Optional[ + Union[ + LocalAutoMLDataChannel, + AutoMLDataChannel, + List[LocalAutoMLDataChannel], + List[AutoMLDataChannel], + ] + ], + wait: bool = True, + logs: bool = True, + job_name: str = None, + ): + """Create an AutoML Job with the input dataset. + + Args: + inputs (LocalAutoMLDataChannel or list(LocalAutoMLDataChannel) or AutoMLDataChannel + or list(AutoMLDataChannel)): Local path or S3 Uri where the training data is stored. + Or an AutoMLDataChannel object. Or a list of AutoMLDataChannel objects. + If a local path in LocalAutoMLDataChannel is provided, + the dataset will be uploaded to an S3 location. + The list of AutoMLDataChannel objects is to specify the training or the validation + input source. Input source for training and validation + must share the same content type and target attribute name. + Minimum number of 1 item. Maximum number of 2 items for list[AutoMLDataChannel]. + wait (bool): Whether the call should wait until the job completes (default: True). + logs (bool): Whether to show the logs produced by the job. Only meaningful when wait + is True (default: True). if ``wait`` is False, ``logs`` will be set to False as + well. + job_name (str): The job name. If not specified, the estimator generates + a default job name, based on the training image name and current timestamp. + """ + if not wait and logs: + logs = False + logger.warning("Setting logs to False. logs is only meaningful when wait is True.") + + # upload data for users if provided local path with LocalAutoMLDataChannel + if isinstance(inputs, LocalAutoMLDataChannel): + inputs = _upload_local_dataset(inputs, self.sagemaker_session) + elif isinstance(inputs, list) and all( + isinstance(channel, LocalAutoMLDataChannel) for channel in inputs + ): + inputs = [_upload_local_dataset(channel, self.sagemaker_session) for channel in inputs] + + self._prepare_for_auto_ml_job(job_name=job_name) + self.inputs = inputs + self.latest_auto_ml_job = AutoMLJobV2.start_new(self, inputs) # pylint: disable=W0201 + if wait: + self.latest_auto_ml_job.wait(logs=logs) + + @classmethod + def attach(cls, auto_ml_job_name, sagemaker_session=None): + """Attach to an existing AutoML job. + + Creates and returns a AutoML bound to an existing automl job. + + Args: + auto_ml_job_name (str): AutoML job name + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (default: None). If not + specified, the one originally associated with the ``AutoML`` instance is used. + + Returns: + sagemaker.automl.AutoML: A ``AutoMLV2`` instance with the attached automl job. 
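+
+        Example (a minimal sketch; the job name below is a hypothetical placeholder)::
+
+            from sagemaker.automl.automlv2 import AutoMLV2
+
+            automl = AutoMLV2.attach(auto_ml_job_name="my-automl-v2-job")
+            automl.describe_auto_ml_job()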
+ + """ + sagemaker_session = sagemaker_session or Session() + + auto_ml_job_desc = sagemaker_session.describe_auto_ml_job_v2(auto_ml_job_name) + automl_job_tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=auto_ml_job_desc["AutoMLJobArn"] + )["Tags"] + inputs = [ + AutoMLDataChannel.from_response_dict(channel) + for channel in auto_ml_job_desc["AutoMLJobInputDataConfig"] + ] + + problem_type = auto_ml_job_desc["AutoMLProblemTypeConfigName"] + problem_config = None + if problem_type == "ImageClassification": + problem_config = AutoMLImageClassificationConfig.from_response_dict( + auto_ml_job_desc["AutoMLProblemTypeConfig"]["ImageClassificationJobConfig"] + ) + elif problem_type == "TextClassification": + problem_config = AutoMLTextClassificationConfig.from_response_dict( + auto_ml_job_desc["AutoMLProblemTypeConfig"]["TextClassificationJobConfig"] + ) + elif problem_type == "TimeSeriesForecasting": + problem_config = AutoMLTimeSeriesForecastingConfig.from_response_dict( + auto_ml_job_desc["AutoMLProblemTypeConfig"]["TimeSeriesForecastingJobConfig"] + ) + elif problem_type == "Tabular": + problem_config = AutoMLTabularConfig.from_response_dict( + auto_ml_job_desc["AutoMLProblemTypeConfig"]["TabularJobConfig"] + ) + elif problem_type == "TextGeneration": + problem_config = AutoMLTextGenerationConfig.from_response_dict( + auto_ml_job_desc["AutoMLProblemTypeConfig"]["TextGenerationJobConfig"] + ) + + amlj = AutoMLV2( + role=auto_ml_job_desc["RoleArn"], + problem_config=problem_config, + output_path=auto_ml_job_desc["OutputDataConfig"]["S3OutputPath"], + output_kms_key=auto_ml_job_desc["OutputDataConfig"].get("KmsKeyId"), + base_job_name=auto_ml_job_name, + sagemaker_session=sagemaker_session, + volume_kms_key=auto_ml_job_desc.get("SecurityConfig", {}).get("VolumeKmsKeyId"), + # Do not override encrypt_inter_container_traffic from config because this info + # is pulled from an existing automl job + encrypt_inter_container_traffic=auto_ml_job_desc.get("SecurityConfig", {}).get( + "EnableInterContainerTrafficEncryption" + ), + vpc_config=auto_ml_job_desc.get("SecurityConfig", {}).get("VpcConfig"), + job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}), + auto_generate_endpoint_name=auto_ml_job_desc.get("ModelDeployConfig", {}).get( + "AutoGenerateEndpointName", False + ), + endpoint_name=auto_ml_job_desc.get("ModelDeployConfig", {}).get("EndpointName"), + validation_fraction=auto_ml_job_desc.get("DataSplitConfig", {}).get( + "ValidationFraction" + ), + tags=automl_job_tags, + ) + amlj.current_job_name = auto_ml_job_name + amlj.latest_auto_ml_job = auto_ml_job_name # pylint: disable=W0201 + amlj._auto_ml_job_desc = auto_ml_job_desc + amlj.inputs = inputs + return amlj + + def describe_auto_ml_job(self, job_name=None): + """Returns the job description of an AutoML job for the given job name. + + Args: + job_name (str): The name of the AutoML job to describe. + If None, will use object's latest_auto_ml_job name. + + Returns: + dict: A dictionary response with the AutoML Job description. + """ + if job_name is None: + job_name = self.current_job_name + self._auto_ml_job_desc = self.sagemaker_session.describe_auto_ml_job_v2(job_name=job_name) + return self._auto_ml_job_desc + + def best_candidate(self, job_name=None): + """Returns the best candidate of an AutoML job for a given name. + + Args: + job_name (str): The name of the AutoML job. If None, object's + _current_auto_ml_job_name will be used. + + Returns: + dict: A dictionary with information of the best candidate. 
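+
+        Example (a sketch; assumes ``automl`` is an attached or fitted ``AutoMLV2``
+        instance whose job has already completed)::
+
+            best = automl.best_candidate()
+            print(best["CandidateName"], best["FinalAutoMLJobObjectiveMetric"]["Value"])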
+ """ + if self._best_candidate: + return self._best_candidate + + if job_name is None: + job_name = self.current_job_name + if self._auto_ml_job_desc is None: + self._auto_ml_job_desc = self.sagemaker_session.describe_auto_ml_job_v2( + job_name=job_name + ) + elif self._auto_ml_job_desc["AutoMLJobName"] != job_name: + self._auto_ml_job_desc = self.sagemaker_session.describe_auto_ml_job_v2( + job_name=job_name + ) + + self._best_candidate = self._auto_ml_job_desc["BestCandidate"] + return self._best_candidate + + def list_candidates( + self, + job_name=None, + status_equals=None, + candidate_name=None, + candidate_arn=None, + sort_order=None, + sort_by=None, + max_results=None, + ): + """Returns the list of candidates of an AutoML job for a given name. + + Args: + job_name (str): The name of the AutoML job. If None, will use object's + _current_job name. + status_equals (str): Filter the result with candidate status, values could be + "Completed", "InProgress", "Failed", "Stopped", "Stopping" + candidate_name (str): The name of a specified candidate to list. + Default to None. + candidate_arn (str): The Arn of a specified candidate to list. + Default to None. + sort_order (str): The order that the candidates will be listed in result. + Default to None. + sort_by (str): The value that the candidates will be sorted by. + Default to None. + max_results (int): The number of candidates will be listed in results, + between 1 to 100. Default to None. If None, will return all the candidates. + + Returns: + list: A list of dictionaries with candidates information. + """ + if job_name is None: + job_name = self.current_job_name + + list_candidates_args = {"job_name": job_name} + + if status_equals: + list_candidates_args["status_equals"] = status_equals + if candidate_name: + list_candidates_args["candidate_name"] = candidate_name + if candidate_arn: + list_candidates_args["candidate_arn"] = candidate_arn + if sort_order: + list_candidates_args["sort_order"] = sort_order + if sort_by: + list_candidates_args["sort_by"] = sort_by + if max_results: + list_candidates_args["max_results"] = max_results + + return self.sagemaker_session.list_candidates(**list_candidates_args)["Candidates"] + + def create_model( + self, + name, + sagemaker_session=None, + candidate=None, + vpc_config=None, + enable_network_isolation=False, + model_kms_key=None, + predictor_cls=None, + inference_response_keys=None, + ): + """Creates a model from a given candidate or the best candidate from the job. + + Args: + name (str): The pipeline model name. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (default: None). If not + specified, the one originally associated with the ``AutoML`` instance is used.: + candidate (CandidateEstimator or dict): a CandidateEstimator used for deploying + to a SageMaker Inference Pipeline. If None, the best candidate will + be used. If the candidate input is a dict, a CandidateEstimator will be + created from it. + vpc_config (dict): Specifies a VPC that your training jobs and hosted models have + access to. Contents include "SecurityGroupIds" and "Subnets". + enable_network_isolation (bool): Isolates the training container. No inbound or + outbound network calls can be made, except for calls between peers within a + training cluster for distributed training. 
Default: False + model_kms_key (str): KMS key ARN used to encrypt the repacked + model archive file if the model is repacked + predictor_cls (callable[string, sagemaker.session.Session]): A + function to call to create a predictor (default: None). If + specified, ``deploy()`` returns the result of invoking this + function on the created endpoint name. + inference_response_keys (list): List of keys for response content. The order of the + keys will dictate the content order in the response. + + Returns: + PipelineModel object. + """ + sagemaker_session = sagemaker_session or self.sagemaker_session + + if candidate is None: + candidate_dict = self.best_candidate() + candidate = CandidateEstimator(candidate_dict, sagemaker_session=sagemaker_session) + elif isinstance(candidate, dict): + candidate = CandidateEstimator(candidate, sagemaker_session=sagemaker_session) + + inference_containers = candidate.containers + + self.validate_and_update_inference_response(inference_containers, inference_response_keys) + + models = [] + + for container in inference_containers: + model = Model( + image_uri=container["Image"], + model_data=container["ModelDataUrl"], + role=self.role, + env=container["Environment"], + vpc_config=vpc_config, + sagemaker_session=sagemaker_session or self.sagemaker_session, + enable_network_isolation=enable_network_isolation, + model_kms_key=model_kms_key, + ) + models.append(model) + + pipeline = PipelineModel( + models=models, + role=self.role, + predictor_cls=predictor_cls, + name=name, + vpc_config=vpc_config, + enable_network_isolation=enable_network_isolation, + sagemaker_session=sagemaker_session or self.sagemaker_session, + ) + return pipeline + + def deploy( + self, + initial_instance_count, + instance_type, + serializer=None, + deserializer=None, + candidate=None, + sagemaker_session=None, + name=None, + endpoint_name=None, + tags=None, + wait=True, + vpc_config=None, + enable_network_isolation=False, + model_kms_key=None, + predictor_cls=None, + inference_response_keys=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + ): + """Deploy a candidate to a SageMaker Inference Pipeline. + + Args: + initial_instance_count (int): The initial number of instances to run + in the ``Endpoint`` created from this ``Model``. + instance_type (str): The EC2 instance type to deploy this Model to. + For example, 'ml.p2.xlarge'. + serializer (:class:`~sagemaker.serializers.BaseSerializer`): A + serializer object, used to encode data for an inference endpoint + (default: None). If ``serializer`` is not None, then + ``serializer`` will override the default serializer. The + default serializer is set by the ``predictor_cls``. + deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A + deserializer object, used to decode data from an inference + endpoint (default: None). If ``deserializer`` is not None, then + ``deserializer`` will override the default deserializer. The + default deserializer is set by the ``predictor_cls``. + candidate (CandidateEstimator or dict): a CandidateEstimator used for deploying + to a SageMaker Inference Pipeline. If None, the best candidate will + be used. If the candidate input is a dict, a CandidateEstimator will be + created from it. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (default: None). If not + specified, the one originally associated with the ``AutoML`` instance is used. + name (str): The pipeline model name. 
If None, a default model name will + be selected on each ``deploy``. + endpoint_name (str): The name of the endpoint to create (default: + None). If not specified, a unique endpoint name will be created. + tags (Optional[Tags]): The list of tags to attach to this + specific endpoint. + wait (bool): Whether the call should wait until the deployment of + model completes (default: True). + vpc_config (dict): Specifies a VPC that your training jobs and hosted models have + access to. Contents include "SecurityGroupIds" and "Subnets". + enable_network_isolation (bool): Isolates the training container. No inbound or + outbound network calls can be made, except for calls between peers within a + training cluster for distributed training. Default: False + model_kms_key (str): KMS key ARN used to encrypt the repacked + model archive file if the model is repacked + predictor_cls (callable[string, sagemaker.session.Session]): A + function to call to create a predictor (default: None). If + specified, ``deploy()`` returns the result of invoking this + function on the created endpoint name. + inference_response_keys (list): List of keys for response content. The order of the + keys will dictate the content order in the response. + volume_size (int): The size, in GB, of the ML storage volume attached to individual + inference instance associated with the production variant. Currenly only Amazon EBS + gp2 storage volumes are supported. + model_data_download_timeout (int): The timeout value, in seconds, to download and + extract model data from Amazon S3 to the individual inference instance associated + with this production variant. + container_startup_health_check_timeout (int): The timeout value, in seconds, for your + inference container to pass health check by SageMaker Hosting. For more information + about health check see: + https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests + + Returns: + callable[string, sagemaker.session.Session] or ``None``: + If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on + the created endpoint name. Otherwise, ``None``. + """ + sagemaker_session = sagemaker_session or self.sagemaker_session + model = self.create_model( + name=name, + sagemaker_session=sagemaker_session, + candidate=candidate, + inference_response_keys=inference_response_keys, + vpc_config=vpc_config, + enable_network_isolation=enable_network_isolation, + model_kms_key=model_kms_key, + predictor_cls=predictor_cls, + ) + + return model.deploy( + initial_instance_count=initial_instance_count, + instance_type=instance_type, + serializer=serializer, + deserializer=deserializer, + endpoint_name=endpoint_name, + kms_key=model_kms_key, + tags=format_tags(tags), + wait=wait, + volume_size=volume_size, + model_data_download_timeout=model_data_download_timeout, + container_startup_health_check_timeout=container_startup_health_check_timeout, + ) + + def _prepare_for_auto_ml_job(self, job_name=None): + """Set any values in the AutoMLJob that need to be set before creating request. + + Args: + job_name (str): The name of the AutoML job. If None, a job name will be + created from base_job_name or "sagemaker-auto-ml". 
+ """ + if job_name is not None: + self.current_job_name = job_name + else: + if self.base_job_name: + base_name = self.base_job_name + else: + base_name = "automl" + # CreateAutoMLJob API validates that member length less than or equal to 32 + self.current_job_name = name_from_base(base_name, max_length=32) + + if self.output_path is None: + self.output_path = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + with_end_slash=True, + ) + + def _check_problem_type_and_job_objective(self, problem_type, job_objective): + """Validate if problem_type and job_objective are both None or are both provided. + + Args: + problem_type (str): The type of problem of this AutoMLJob. Valid values are + "Regression", "BinaryClassification", "MultiClassClassification". + job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional), + "MetricName" and "Value". + + Raises (ValueError): raises ValueError if one of problem_type and job_objective is provided + while the other is None. + """ + if not (problem_type and job_objective) and (problem_type or job_objective): + raise ValueError( + "One of problem type and objective metric provided. " + "Either both of them should be provided or none of them should be provided." + ) + + @classmethod + def _get_supported_inference_keys(cls, container, default=None): + """Returns the inference keys supported by the container. + + Args: + container (dict): Dictionary representing container + default (object): The value to be returned if the container definition + has no marker environment variable + + Returns: + List of inference keys the container support or default + + Raises: + KeyError if the default is None and the container definition has + no marker environment variable SAGEMAKER_INFERENCE_SUPPORTED. + """ + try: + return [ + x.strip() + for x in container["Environment"]["SAGEMAKER_INFERENCE_SUPPORTED"].split(",") + ] + except KeyError: + if default is None: + raise + return default + + @classmethod + def _check_inference_keys(cls, inference_response_keys, containers): + """Checks if the pipeline supports the inference keys for the containers. + + Given inference response keys and list of containers, determines whether + the keys are supported. + + Args: + inference_response_keys (list): List of keys for inference response content. + containers (list): list of inference container. + + Raises: + ValueError, if one or more keys in inference_response_keys are not supported + the inference pipeline. + """ + if not inference_response_keys: + return + try: + supported_inference_keys = cls._get_supported_inference_keys(container=containers[-1]) + except KeyError: + raise ValueError( + "The inference model does not support selection of inference content beyond " + "it's default content. Please retry without setting " + "inference_response_keys key word argument." + ) + bad_keys = [] + for key in inference_response_keys: + if key not in supported_inference_keys: + bad_keys.append(key) + + if bad_keys: + raise ValueError( + "Requested inference output keys [{bad_keys_str}] are unsupported. " + "The supported inference keys are [{allowed_keys_str}]".format( + bad_keys_str=", ".join(bad_keys), + allowed_keys_str=", ".join(supported_inference_keys), + ) + ) + + @classmethod + def validate_and_update_inference_response(cls, inference_containers, inference_response_keys): + """Validates the requested inference keys and updates response content. 
+ + On validation, also updates the inference containers to emit appropriate response + content in the inference response. + + Args: + inference_containers (list): list of inference containers + inference_response_keys (list): list of inference response keys + + Raises: + ValueError: if one or more of inference_response_keys are unsupported by the model + """ + if not inference_response_keys: + return + + cls._check_inference_keys(inference_response_keys, inference_containers) + + previous_container_output = None + + for container in inference_containers: + supported_inference_keys_container = cls._get_supported_inference_keys( + container, default=[] + ) + if not supported_inference_keys_container: + previous_container_output = None + continue + current_container_output = None + for key in inference_response_keys: + if key in supported_inference_keys_container: + current_container_output = ( + current_container_output + "," + key if current_container_output else key + ) + + if previous_container_output: + container["Environment"].update( + {"SAGEMAKER_INFERENCE_INPUT": previous_container_output} + ) + if current_container_output: + container["Environment"].update( + {"SAGEMAKER_INFERENCE_OUTPUT": current_container_output} + ) + previous_container_output = current_container_output + + +class AutoMLJobV2(_Job): + """A class for interacting with CreateAutoMLJobV2 API.""" + + def __init__(self, sagemaker_session, job_name, inputs): + self.inputs = inputs + self.job_name = job_name + super(AutoMLJobV2, self).__init__(sagemaker_session=sagemaker_session, job_name=job_name) + + @classmethod + def _get_auto_ml_v2_args(cls, auto_ml, inputs): + """Constructs a dict of arguments for an Amazon SageMaker AutoMLV2 job. + + Args: + auto_ml (sagemaker.automl.AutoMLV2): AutoMLV2 object + created by the user. + inputs (AutoMLDataChannel or list[AutoMLDataChannel]): + Parameters used when called + :meth:`~sagemaker.automl.AutoML.fit`. + + Returns: + Dict: dict for `sagemaker.session.Session.auto_ml` method + """ + config = cls._load_config(inputs, auto_ml) + auto_ml_args = config.copy() + auto_ml_args["job_name"] = auto_ml.current_job_name + auto_ml_args["job_objective"] = auto_ml.job_objective + auto_ml_args["tags"] = auto_ml.tags + + return auto_ml_args + + @classmethod + def start_new(cls, auto_ml, inputs): + """Create a new Amazon SageMaker AutoMLV2 job from auto_ml_v2 object. + + Args: + auto_ml (sagemaker.automl.AutoMLV2): AutoMLV2 object + created by the user. + inputs (AutoMLDataChannel or list[AutoMLDataChannel]): + Parameters used when called + :meth:`~sagemaker.automl.AutoML.fit`. + + Returns: + sagemaker.automl.AutoMLJobV2: Constructed object that captures + all information about the started AutoMLV2 job. + """ + auto_ml_args = cls._get_auto_ml_v2_args(auto_ml, inputs) + + auto_ml.sagemaker_session.create_auto_ml_v2(**auto_ml_args) + return cls(auto_ml.sagemaker_session, auto_ml.current_job_name, inputs) + + @classmethod + def _load_config(cls, inputs, auto_ml, expand_role=True): + """Load job_config, input_config and output config from auto_ml and inputs. + + Args: + inputs (AutoMLDataChannel or list[AutoMLDataChannel]): Parameters used when called + :meth:`~sagemaker.automl.AutoML.fit`. + auto_ml (AutoMLV2): an AutoMLV2 object that user initiated. + expand_role (str): The expanded role arn that allows for Sagemaker + executionts. + validate_uri (bool): indicate whether to validate the S3 uri. 
+ + Returns (dict): a config dictionary that contains input_config, output_config, + problem_config and role information. + + """ + + if not inputs: + msg = ( + "Cannot format input {}. Expecting an AutoMLDataChannel or " + "a list of AutoMLDataChannel or a LocalAutoMLDataChannel or a list of " + "LocalAutoMLDataChannel." + ) + raise ValueError(msg.format(inputs)) + + if isinstance(inputs, AutoMLDataChannel): + input_config = [inputs.to_request_dict()] + elif isinstance(inputs, list) and all( + isinstance(channel, AutoMLDataChannel) for channel in inputs + ): + input_config = [channel.to_request_dict() for channel in inputs] + + output_config = _Job._prepare_output_config(auto_ml.output_path, auto_ml.output_kms_key) + role = auto_ml.sagemaker_session.expand_role(auto_ml.role) if expand_role else auto_ml.role + + problem_config = auto_ml.problem_config.to_request_dict() + + config = { + "input_config": input_config, + "output_config": output_config, + "problem_config": problem_config, + "role": role, + "job_objective": auto_ml.job_objective, + } + + if ( + auto_ml.volume_kms_key + or auto_ml.vpc_config + or auto_ml.encrypt_inter_container_traffic is not None + ): + config["security_config"] = {} + if auto_ml.volume_kms_key: + config["security_config"]["VolumeKmsKeyId"] = auto_ml.volume_kms_key + if auto_ml.vpc_config: + config["security_config"]["VpcConfig"] = auto_ml.vpc_config + if auto_ml.encrypt_inter_container_traffic is not None: + config["security_config"][ + "EnableInterContainerTrafficEncryption" + ] = auto_ml.encrypt_inter_container_traffic + + # Model deploy config + + auto_ml_model_deploy_config = {} + if auto_ml.auto_generate_endpoint_name is not None: + auto_ml_model_deploy_config[ + "AutoGenerateEndpointName" + ] = auto_ml.auto_generate_endpoint_name + if not auto_ml.auto_generate_endpoint_name and auto_ml.endpoint_name is not None: + auto_ml_model_deploy_config["EndpointName"] = auto_ml.endpoint_name + + if auto_ml_model_deploy_config: + config["model_deploy_config"] = auto_ml_model_deploy_config + # Data split config + if auto_ml.validation_fraction is not None: + config["data_split_config"] = {"ValidationFraction": auto_ml.validation_fraction} + return config + + def describe(self): + """Returns a response from the DescribeAutoMLJobV2 API call.""" + return self.sagemaker_session.describe_auto_ml_job_v2(job_name=self.job_name) + + def wait(self, logs=True): + """Wait for the AutoML job to finish. + + Args: + logs (bool): indicate whether to output logs. 
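+
+        Example (a sketch; assumes ``automl.fit(..., wait=False)`` was called earlier so
+        that ``latest_auto_ml_job`` holds a running job)::
+
+            automl.latest_auto_ml_job.wait(logs=False)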
+ """ + if logs: + self.sagemaker_session.logs_for_auto_ml_job(job_name=self.job_name, wait=True) + else: + self.sagemaker_session.wait_for_auto_ml_job(job_name=self.job_name) diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index 860cfdba06..93919e5650 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -47,9 +47,12 @@ MONITORING_SCHEDULE, MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, AUTO_ML_ROLE_ARN_PATH, + AUTO_ML_V2_ROLE_ARN_PATH, AUTO_ML_OUTPUT_CONFIG_PATH, + AUTO_ML_V2_OUTPUT_CONFIG_PATH, AUTO_ML_JOB_CONFIG_PATH, AUTO_ML_JOB, + AUTO_ML_JOB_V2, COMPILATION_JOB_ROLE_ARN_PATH, COMPILATION_JOB_OUTPUT_CONFIG_PATH, COMPILATION_JOB_VPC_CONFIG_PATH, @@ -111,9 +114,13 @@ FEATURE_GROUP_OFFLINE_STORE_KMS_KEY_ID_PATH, FEATURE_GROUP_ONLINE_STORE_KMS_KEY_ID_PATH, AUTO_ML_KMS_KEY_ID_PATH, + AUTO_ML_V2_KMS_KEY_ID_PATH, AUTO_ML_VPC_CONFIG_PATH, + AUTO_ML_V2_VPC_CONFIG_PATH, AUTO_ML_VOLUME_KMS_KEY_ID_PATH, + AUTO_ML_V2_VOLUME_KMS_KEY_ID_PATH, AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, + AUTO_ML_V2_INTER_CONTAINER_ENCRYPTION_PATH, ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH, SESSION_DEFAULT_S3_BUCKET_PATH, SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 3f5352cc1f..35c4859930 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -83,6 +83,7 @@ ENDPOINT = "Endpoint" INFERENCE_COMPONENT = "InferenceComponent" AUTO_ML_JOB = "AutoMLJob" +AUTO_ML_JOB_V2 = "AutoMLJobV2" COMPILATION_JOB = "CompilationJob" CUSTOM_PARAMETERS = "CustomParameters" PIPELINE = "Pipeline" @@ -182,14 +183,21 @@ def _simple_path(*args: str): FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, SECURITY_CONFIG, KMS_KEY_ID ) AUTO_ML_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG) +AUTO_ML_V2_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, OUTPUT_DATA_CONFIG) AUTO_ML_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, OUTPUT_DATA_CONFIG, KMS_KEY_ID) +AUTO_ML_V2_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, OUTPUT_DATA_CONFIG, KMS_KEY_ID) AUTO_ML_VOLUME_KMS_KEY_ID_PATH = _simple_path( SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VOLUME_KMS_KEY_ID ) +AUTO_ML_V2_VOLUME_KMS_KEY_ID_PATH = _simple_path( + SAGEMAKER, AUTO_ML_JOB_V2, SECURITY_CONFIG, VOLUME_KMS_KEY_ID +) AUTO_ML_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, ROLE_ARN) +AUTO_ML_V2_ROLE_ARN_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, ROLE_ARN) AUTO_ML_VPC_CONFIG_PATH = _simple_path( SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG, SECURITY_CONFIG, VPC_CONFIG ) +AUTO_ML_V2_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB_V2, SECURITY_CONFIG, VPC_CONFIG) AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG) MONITORING_JOB_DEFINITION_PREFIX = _simple_path( SAGEMAKER, @@ -362,6 +370,12 @@ def _simple_path(*args: str): SECURITY_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, ) +AUTO_ML_V2_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( + SAGEMAKER, + AUTO_ML_JOB_V2, + SECURITY_CONFIG, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, +) PROCESSING_JOB_ENVIRONMENT_PATH = _simple_path(SAGEMAKER, PROCESSING_JOB, ENVIRONMENT) PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( SAGEMAKER, PROCESSING_JOB, NETWORK_CONFIG, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION @@ -947,6 +961,30 @@ def _simple_path(*args: str): TAGS: {"$ref": "#/definitions/tags"}, }, }, + # Auto ML V2 + # 
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateAutoMLJobV2.html + AUTO_ML_JOB_V2: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + SECURITY_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {TYPE: "boolean"}, + VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}, + VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + }, + }, + OUTPUT_DATA_CONFIG: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}}, + }, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + TAGS: {"$ref": "#/definitions/tags"}, + }, + }, # Transform Job # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html TRANSFORM_JOB: { diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 2f39ceac39..5f3fa5e5a0 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -65,9 +65,12 @@ MONITORING_SCHEDULE, MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH, AUTO_ML_ROLE_ARN_PATH, + AUTO_ML_V2_ROLE_ARN_PATH, AUTO_ML_OUTPUT_CONFIG_PATH, + AUTO_ML_V2_OUTPUT_CONFIG_PATH, AUTO_ML_JOB_CONFIG_PATH, AUTO_ML_JOB, + AUTO_ML_JOB_V2, COMPILATION_JOB_ROLE_ARN_PATH, COMPILATION_JOB_OUTPUT_CONFIG_PATH, COMPILATION_JOB_VPC_CONFIG_PATH, @@ -2570,7 +2573,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m exceptions.UnexpectedStatusException: If waiting and auto ml job fails. """ - description = _wait_until(lambda: self.describe_auto_ml_job(job_name), poll) + description = _wait_until(lambda: self.describe_auto_ml_job_v2(job_name), poll) instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( self.boto_session, description, job="AutoML" @@ -2618,7 +2621,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m if state == LogState.JOB_COMPLETE: state = LogState.COMPLETE elif time.time() - last_describe_job_call >= 30: - description = self.sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name) + description = self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name) last_describe_job_call = time.time() status = description["AutoMLJobStatus"] @@ -2632,6 +2635,172 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m if dot: print() + def create_auto_ml_v2( + self, + input_config, + job_name, + problem_config, + output_config, + job_objective=None, + model_deploy_config=None, + data_split_config=None, + role=None, + security_config=None, + tags=None, + ): + """Create an Amazon SageMaker AutoMLV2 job. + + Args: + input_config (list[dict]): A list of AutoMLDataChannel objects. + Each channel contains "DataSource" and other optional fields. + job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob + should have a unique job name. + problem_config (object): A collection of settings specific + to the problem type used to configure an AutoML job V2. + There must be one and only one config of the following type. + Supported problem types are: + + - Image Classification (sagemaker.automl.automlv2.ImageClassificationJobConfig), + - Tabular (sagemaker.automl.automlv2.TabularJobConfig), + - Text Classification (sagemaker.automl.automlv2.TextClassificationJobConfig), + - Text Generation (TextGenerationJobConfig), + - Time Series Forecasting ( + sagemaker.automl.automlv2.TimeSeriesForecastingJobConfig). 
+
+            output_config (dict): The S3 URI where you want to store the training results and
+                optional KMS key ID.
+            job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
+                "MetricName" and "Value".
+            model_deploy_config (dict): Specifies how to generate the endpoint name
+                for an automatic one-click Autopilot model deployment.
+                Contains "AutoGenerateEndpointName" and "EndpointName".
+            data_split_config (dict): This structure specifies how to split the data
+                into train and validation datasets.
+            role (str): The Amazon Resource Name (ARN) of an IAM role that
+                Amazon SageMaker can assume to perform tasks on your behalf.
+            security_config (dict): The security configuration for traffic encryption
+                or Amazon VPC settings.
+            tags (Optional[Tags]): A list of dictionaries containing key-value
+                pairs.
+        """
+
+        role = resolve_value_from_config(role, AUTO_ML_V2_ROLE_ARN_PATH, sagemaker_session=self)
+        inferred_output_config = update_nested_dictionary_with_values_from_config(
+            output_config, AUTO_ML_V2_OUTPUT_CONFIG_PATH, sagemaker_session=self
+        )
+
+        auto_ml_job_v2_request = self._get_auto_ml_request_v2(
+            input_config=input_config,
+            job_name=job_name,
+            problem_config=problem_config,
+            output_config=inferred_output_config,
+            role=role,
+            job_objective=job_objective,
+            model_deploy_config=model_deploy_config,
+            data_split_config=data_split_config,
+            security_config=security_config,
+            tags=format_tags(tags),
+        )
+
+        def submit(request):
+            logger.info("Creating auto-ml-v2-job with name: %s", job_name)
+            logger.debug("auto ml v2 request: %s", json.dumps(request, indent=4))
+            self.sagemaker_client.create_auto_ml_job_v2(**request)
+
+        self._intercept_create_request(
+            auto_ml_job_v2_request, submit, self.create_auto_ml_v2.__name__
+        )
+
+    def _get_auto_ml_request_v2(
+        self,
+        input_config,
+        output_config,
+        job_name,
+        problem_config,
+        role,
+        job_objective=None,
+        model_deploy_config=None,
+        data_split_config=None,
+        security_config=None,
+        tags=None,
+    ):
+        """Constructs a request compatible with creating an Amazon SageMaker AutoML V2 job.
+
+        Args:
+            input_config (list[dict]): A list of AutoMLDataChannel objects in API request
+                format. Each channel contains "DataSource"; "ChannelType", "CompressionType"
+                and "ContentType" are optional fields.
+            output_config (dict): The S3 URI where you want to store the training results and
+                optional KMS key ID.
+            job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob
+                should have a unique job name.
+            problem_config (object): A collection of settings specific
+                to the problem type used to configure an AutoML job V2.
+                There must be one and only one config of the following type.
+                Supported problem types are:
+
+                - Image Classification (sagemaker.automl.automlv2.ImageClassificationJobConfig),
+                - Tabular (sagemaker.automl.automlv2.TabularJobConfig),
+                - Text Classification (sagemaker.automl.automlv2.TextClassificationJobConfig),
+                - Text Generation (TextGenerationJobConfig),
+                - Time Series Forecasting (
+                    sagemaker.automl.automlv2.TimeSeriesForecastingJobConfig).
+
+            role (str): The Amazon Resource Name (ARN) of an IAM role that
+                Amazon SageMaker can assume to perform tasks on your behalf.
+            job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
+                "MetricName" and "Value".
+            model_deploy_config (dict): Specifies how to generate the endpoint name
+                for an automatic one-click Autopilot model deployment.
+ Contains "AutoGenerateEndpointName" and "EndpointName" + data_split_config (dict): This structure specifies how to split the data + into train and validation datasets. + security_config (dict): The security configuration for traffic encryption + or Amazon VPC settings. + tags (Optional[Tags]): A list of dictionaries containing key-value + pairs. + + Returns: + Dict: a automl v2 request dict + """ + auto_ml_job_v2_request = { + "AutoMLJobName": job_name, + "AutoMLJobInputDataConfig": input_config, + "OutputDataConfig": output_config, + "AutoMLProblemTypeConfig": problem_config, + "RoleArn": role, + } + if job_objective is not None: + auto_ml_job_v2_request["AutoMLJobObjective"] = job_objective + if model_deploy_config is not None: + auto_ml_job_v2_request["ModelDeployConfig"] = model_deploy_config + if data_split_config is not None: + auto_ml_job_v2_request["DataSplitConfig"] = data_split_config + if security_config is not None: + auto_ml_job_v2_request["SecurityConfig"] = security_config + + tags = _append_project_tags(format_tags(tags)) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB_V2, TAGS) + ) + if tags is not None: + auto_ml_job_v2_request["Tags"] = tags + + return auto_ml_job_v2_request + + # Done + def describe_auto_ml_job_v2(self, job_name): + """Calls the DescribeAutoMLJobV2 API for the given job name and returns the response. + + Args: + job_name (str): The name of the AutoML job to describe. + + Returns: + dict: A dictionary response with the AutoMLV2 Job description. + """ + return self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name) + def compile_model( self, input_model_config, diff --git a/tests/data/automl/data/CoLA.csv b/tests/data/automl/data/CoLA.csv new file mode 100644 index 0000000000..beedd3115f --- /dev/null +++ b/tests/data/automl/data/CoLA.csv @@ -0,0 +1,500 @@ +text,label +Bill cooked the rice.,1 +Martha carved a toy for the baby.,1 +Susan whispered Rachel the news.,0 +Harry got to be as much of a celebrity as his father.,1 +That the fuzz wanted him worried John but it didn't worry Mary.,1 +George sang to himself.,1 +They shopped their way around New York.,1 +Every student who ever goes to Europe ever has enough money.,0 +Mary or Sue could tell you that.,1 +What Julie did of Lloyd was become fond.,0 +This sonata is easy to play on this piano.,1 +There was him in the garden.,0 +I brought a razor to shave myself with.,0 +John is taller than Pete is.,1 +He gave my binoculars to that girl.,1 +Monica covered a blanket over the baby.,0 +He attributed to a short circuit the fire which.,1 +The window was broken with a hammer.,1 +"If the telephone rang, it could ring itself silly.",1 +Sarah sang a ballad.,1 +"The more that pictures of him appear in the news, the more likely John is to get arrested.",1 +Bill floated down the river for hours.,1 +They expected us to should leave him.,0 +"It is important for you to be more careful, the more you eat.",1 +The vacuum cleaner frightens the child.,1 +Calvin will not eat the beef waffles.,1 +Carol cut the whole wheat bread.,1 +I served my guests.,1 +Bill floated down the river.,1 +"China is a country that Joe wants to visit, and he will too, if he gets enough money.",1 +The box contains the ball.,1 +Which boy works in a skyscraper and the girl in a quonset hut?,0 +Bill believes that Anna and he are similar.,1 +To which city and which conference did Bill go?,1 +She is proud.,1 +The reporters expected that the principal would fire some teacher.,1 +Andy promised us to 
go.,0 +He washed myself.,0 +We believed it to be a fountain in the park.,0 +John disappeared.,1 +It is the best students he gives this book.,0 +Maxwell isn't half the doctor that was here.,0 +Pat remembered the appointment and that it was important to be on time.,1 +Linda winked at the audience.,1 +They may grow as high as bamboo.,1 +John loves pictures of himself.,1 +Some people yell at dogs.,1 +"You should always lock your door, no matter how fancy the hotel.",1 +Lou hoped the umbrella in the closet.,0 +"Kim considered joining the navy, but I never considered.",0 +Susan begged to be allowed to sing in the concert.,1 +"The contractor will build a house for $ 100,000.",1 +Tom played it under the table.,1 +"You get angrier, the more we eat, don't you.",1 +Tess knocks at the door.,1 +Bill ate the peaches and Harry did the grapes.,0 +Gary and Kevin ran himself into exhaustion.,0 +Donna fixed a sandwich for me.,1 +I want to should eat plums.,0 +They understand that you will prefer for the girl to put a picture on your desk.,1 +Ann bought a first edition of Richard III for $1000.,1 +One of major factors affecting the value of diamonds was their weight.,1 +"While I might want to, this is the kind of thing that Harris has already suggested.",1 +Richard's gift of the helicopter to the hospital and of the bus to the school.,1 +What did Bill cook supper and wash?,0 +What a kind of actor is he?,0 +They may grow as much as bamboo high.,0 +What I meant was that you have done it really well.,1 +Benjamin gave the cloak 0 and sent the book to Lee,1 +The psychologist hates phonology.,1 +I believed she was pregnant,1 +The dragons have all been slain.,1 +Gwen exchanged the dress for a shirt.,1 +The effectiveness of teaching and learning depend on several factors.,1 +Frances has had her clean the drapes.,1 +Mary believes that Bill saw herself.,0 +"They're going to serve the guests something, but it's unclear what.",1 +All through the mountains raged a fire.,1 +"Medea thought that, after the executioner had left, Poseidon would be relieved.",1 +"John might pass the exam, and as might Bill.",1 +It is difficult for me to concentrate on calculus.,1 +The chair pushed.,0 +John is a decidedly too tall man.,0 +I shipped the bicycle from my house at the beach to my house in the mountains.,1 +We made them be rude.,1 +I don't remember what all I said?,1 +There run many people.,0 +John hit the stone against the wall.,1 +Any student in Mary's class happened to vote.,0 +Julie maintained her own ideas over the course of the argument.,1 +her will the race.,0 +Into which room did Jeeves sauntered?,1 +I broke the twig and the branch.,0 +David broke the window with a hammer.,1 +They praised the dedication in the volunteers.,0 +We talked about the fact that he was sick for days.,1 +Mary is such a wit that people are afraid of her.,1 +The tree trembled.,1 +The piano kicked a student.,1 +"Sorry, I gave last week.",1 +The official to whom Smith loaned the money has been indicted.,1 +Pat persuaded Leslie to be aggressive.,1 +"The lady singing with a boy is a genius, isn't he?",1 +That Jack sometimes slept is impossible.,1 +Spot ate a cat treat.,1 +Henry cleared the table of dishes.,1 +Jack eats caviar more than he eats mush.,1 +it convinced Bill that Mary should sleep.,0 +It is an argument that people think will never end in Egypt.,1 +Jessica squirted water at me.,1 +I want to go.,1 +Who did you see in Las Vegas?,1 +The artist drew the child with a pencil.,1 +He made a statement everyone thought was interesting and important.,1 
+That we invaded Iraq really freaks me out.,1 +"After Henry had touched a sword, Webster did so.",1 +I see.,1 +Ellen told a story.,1 +For him to win the race would surprise them.,1 +Peter thinks that Cathy loves him.,1 +Ellen talked with Helen about the problem.,1 +There may exist a man in the park.,1 +I sang a song with Mary while you did so with Bill.,1 +The persons on whom we kept tabs all proved to be innocent.,0 +"The more people that arrive, the louder that it gets.",1 +I read him two statements about himself.,1 +The man I saw left.,1 +I want that Bill left to remain a secret.,0 +I saw him leaving the main building.,1 +I separated the egg yolk and the egg white.,1 +On his finger sparkled a magnificent diamond.,1 +Wash you!,0 +Doug cleared the table.,1 +Larry all hunted the foxes.,0 +Who has drunk my whiskey?,1 +Michelle kept the papers in the desk.,1 +More than two students attended every seminar.,1 +Aphrodite may quickly free the animals.,1 +He is filled with sincerity.,1 +I deny that that Bob has some money is certain.,0 +The student knows the answers.,1 +Ellen was conferring.,1 +Jack ate more of this than he ate of that.,1 +He's fool.,0 +I wonder has Mary worked for Microsoft.,1 +Truman visited yesterday you.,0 +The president abandoned the people that voted for him.,1 +Doug blew the building up.,1 +I twirled the dough into a pretzel.,1 +I put a book on it.,1 +I would prefer that he have not finished.,0 +"The pictures of Bill, she put on your desk.",1 +The children wails,0 +John put in the box.,0 +It is likely that John has loved Mary.,1 +Julie felt he was there,1 +Jake sent the box towards Carson.,0 +The judge offered a prize to the winner.,1 +Medea exclaimed if the potion was ready,0 +The correspondence school made Bill a good typist.,1 +"The ball, a man kicked.",1 +A medal was been given to the mayor by the sewer commissioner.,0 +Are you studying English syntax?,1 +Heidi thinks that Andy to eat salmon flavored candy bars.,0 +We made the claim that Perseus killed the Gorgon.,1 +You said that Anson thought that Julie had fainted,1 +There arose a great storm.,1 +she is the mother of John.,1 +I wonder who what bought?,0 +For Aphrodite to appear to be happy would be impossible.,1 +Maytag will give a brand-new dryer to the winner of the Mrs.,0 +"If Emma had left Hartfield, Mr Woodhouse would have been unhappy.",1 +"He never achieved anything, did he?",1 +Ellen warned that melons were selling.,1 +"Fluffy is sick, as nobody knows.",0 +I read his every book.,1 +What Mary did with Bill was sing a song.,1 +They seemed all to like John.,1 +I think Rosie loves magazine ads.,1 +Jessica sprayed paint onto the table.,1 +I put the table with books.,0 +Water filled the pail.,1 +George wrote a volume of poems in Latin for Jane.,1 +He let the cats which were whining out.,1 +The writers could so believe the boy.,1 +We rummaged papers through the desk.,0 +Everyone reported that Max and some lady disappeared.,1 +Anson believed to be happy.,0 +"When Bill smokes, all the more Susan hates him.",0 +Did you see Mary?,1 +To whom did you send the package?,1 +Julie and Jenny arrived first,1 +John is ready for you to inspect his bunk.,1 +That gangsters had bribed him was denied by the sheriff.,1 +John bought any picture of Queen Elizabeth that was on sale.,1 +Ellen argued Helen.,0 +Bill pounded the metal fiat.,1 +The jeweller copied the name on the ring.,1 +"The knife, which he threw into the sea had a gold handle.",1 +Fred must have been singing songs and probably was drinking beer.,1 +Books send easily to 
children.,0 +"John, who and whose friends you saw, is a fool.",0 +The ball kicked a man.,1 +The king thanked the man.,1 +I don't know whether to agree with him or not.,1 +He walked up the hill.,1 +Sincerity is an important quality.,1 +Mike will sing if you will sing.,1 +I both went to the store and bought some whiskey.,1 +Tom tends to avoid confrontations.,1 +Sally eats pretty often the stuff.,0 +The boys swim.,0 +My pastor says I ate too much cake.,1 +I think Anita may have poisoned the meatballs which Sasha is gobbling down dumplings faster than I can reheat.,0 +John went to the store.,1 +The students called me a teacher.,1 +Rory eats.,1 +He treats John very nicely.,1 +It is not allowed to incriminate oneself.,1 +In the classroom John put the book on the table.,1 +John hit the ball with a bat.,1 +I know you ate asparagus.,1 +John scratched his arm and Mary did too.,1 +The Greeks arrived all.,0 +John deposited some money in the checking account on Friday and Mary did the same thing on Monday.,1 +The secretary transcribed the record with the speech.,0 +Benjamin thought he would give the cloak to Lee and the cloak to Lee he gave.,0 +"Whenever Bill smokes, Susan hates him a lot more.",1 +Two drops sanitize anything in your house.,1 +Kim must baked a cake.,0 +Whiskey do I drink.,0 +"Kathy likes astronomy, but she doesn't meteorology?",1 +I am both expecting to get the job and of the opinion that it is a desirable one.,1 +It is easy to slay the Gorgon.,1 +"They play unusual music, and I listen to unusual music.",1 +Have you seen my model airplane collection?,1 +The teacher meant well.,1 +Mary jumped the horse over the last fence perfectly.,1 +"That the fuzz wanted him worried John, but that the fuzz wanted her didn't worry Mary.",1 +It is ready to please John.,0 +Birds sings.,0 +Neither of these men is worthy to lead Italy.,1 +I lifted him up the books.,0 +It tried to bother me that Chris lied.,0 +"Mary thinks that John said that pictures of himself, Susan likes?",1 +What she did was e-mail all her friends.,1 +Paula swatted flies.,1 +They talked about the scandal for days.,1 +Brutus murdered Julius Caesar.,1 +I think it's time you give your lovely illness to someone else!,1 +How do you wonder what John bought?,0 +Imogen broke the vase.,1 +The socks are ready for you to go about beginning to put them on.,1 +I got phoned by a woman friend.,1 +John hopes to sleep.,1 +Marianne did not leave.,1 +I shaped the dough.,1 +I went to the store to have bought some whisky.,0 +John talked to everybody who came up to him at the party.,1 +John believed it that Bill was tardy.,0 +Merlin is a dangerous sorcerer.,1 +Barbara handed the intriguing results of the latest examination to Alan on Tuesday.,1 +Ron asked that the potion was ready,0 +They have no went.,0 +Bill is always complaining about the guys who work near him.,1 +She was sent to Seoul.,1 +Nobody likes us.,1 +John is taller than six feet.,1 +"Your desk before, this girl in the red coat will put a picture of Bill on your desk before tomorrow.",0 +We took the car to the town,1 +John has a hole in the upper right-hand corner of his quilt.,1 +The cup was broken by Louise.,1 +"They said that Tom would pay up, and pay up I believe that he did.",1 +It is to Cleveland that John drove the truck.,1 +Each of the boys fought with some of the other boys.,1 +Brenda and Molly met.,1 +Mary persuaded John to fix the computer.,1 +My eyes are itching me.,1 +The phone company billed $10 to me.,0 +John must do not have eaten.,0 +The price of meat fell.,1 +Augusta blamed 
herself for what happened.,1 +How fierce the battle?,0 +Shannon sent Dan an email.,1 +Those days Bill offered Mary anything he cooked.,1 +the person whom John gave the book to left.,1 +The books put on the table.,0 +Pat was neither recommended for promotion nor under any illusions about what that meant.,1 +The room contains few armchairs.,1 +"He'll no can do it, can he?",1 +Bob does not think that there is anyone from Greece in his basement.,1 +Wendy's mother country is Iceland.,1 +John saw Stephan,1 +Betsy loves herself in blue leather.,1 +The president declared Smith press secretary.,1 +John talked to any powerful politician.,0 +The boys should all could go,1 +Who did Kim work for and Sandy rely?,0 +we need to provide two trees and.,1 +An ancient treasure trove was found in this cave.,1 +Tom is confident that the elephants respect him.,1 +John decided Bill to get the prize.,0 +How did Sheila marry tall a man?,0 +Lenin believes the Tsar to be a power hungry dictator.,1 +It is this hat that I believe that he was wearing.,1 +Leigh threw the ball to Lane.,1 +Paris is no more,1 +I put the box the book.,0 +Which boy's guardian's did we elect employer president?,0 +John convinced Bill to visit Mary.,1 +George built both the houses.,1 +Gilgamesh might have been not reading the cuneiform tablets.,0 +Kim jogs over the hill.,1 +He is a fool.,1 +Heidi investigated whether John ate the cauliflower.,1 +He blew up the building.,1 +Kim beating his dog and alienates cats.,0 +I am not certain about if he will come.,0 +The captain ordered the troops to proceed.,1 +I love but you hate ice milk tea.,1 +One plan of which I got wind was calculated to keep us in suspense.,0 +Genie intoned that she was tired,1 +John sang a carol.,1 +That tall woman chased a cat.,1 +They were believed all to be quite diligent.,1 +You would have a reply if you come back tomorrow.,1 +The witch turned him into a frog.,1 +He left.,1 +The authorities blamed Greenpeace with the bombing.,0 +the girls likes herself.,0 +"She talked to Harry, but I don't know who else.",1 +Kim lives in the house Lee sold it to her.,0 +I funneled the bottle with the mixture.,0 +I fear tigers.,1 +It would be inconvenience to leave so soon.,1 +I forgot how good beer tastes.,1 +He washed himself.,1 +Jack eats more caviar than he sleeps.,0 +Jack is claiming that you won't need it.,1 +I called almost all of the men from Boston up.,1 +Zeke cooked and ate the chili.,1 +"Stories about him seem to show up more on the evening news, the more that John gets upset by them.",1 +"Since Jill said Joe had invited Sue, we didn't have to ask who.",0 +John can go to the market on his bike.,1 +Both the twins might have been at the party.,1 +Pick any of these flowers.,1 +I consider to be a fool the senator who made the opening speech.,0 +This girl in the red coat will eat her breakfast and will put a picture of Bill on your desk before tomorrow.,1 +Where all did they go for their holidays?,1 +John demanded that she stop phoning him.,1 +"Yes, she will.",1 +Few writers and any playwrights meet in Vienna.,0 +Who did you say that John thought would leave early?,1 +He put the money where Lee told him to put it.,1 +Alan told me who wanted to seem to be invincible.,1 +"Mary claimed that eating cabbage, Holly started.",0 +John bought a lot of books for his sons.,1 +The answer is known to Sue.,1 +Everyone claimed that the poison was neutralized.,1 +Tom's dog with one eye attacked Frank's with three legs.,0 +The cat trotted in the kitchen.,1 +I had a realization of there.,0 +To go and 
buying whiskey is not the solution to your problem.,0 +Adam asked if Hyacinth likes pineapples.,1 +This needs mending the shoe.,0 +I told Daniel that the exam was cancelled.,1 +Bill is resembled by John.,0 +He was arrested by the police.,1 +The stone knocked the pole into the road.,1 +Wilt is taller than I believe that Bill is.,1 +You may pick every flower.,1 +Strings have been pulled many times to get students into that university.,1 +There ran many people.,0 +There seems to be a nurse available.,1 +The only girl for whom it bothers me to wear that old fedora is Annabelle.,0 +Surprised me that you came early.,0 +We think that Bill left.,1 +Ida hunted deer in the woods.,1 +Sam gave the ball off the shelf.,0 +He's the happiest that we ever talked to the boy who had seen him.,0 +We associated their subsidiaries with our corporate office.,1 +Stephen believed there to be a fountain in the park.,1 +Brian wiped the fingerprints from outside the cupboard.,1 +I prefer for the girl to will win.,0 +The analysis of Lucy took longer than that of Gomez.,1 +They made him angry.,1 +"Mary questioned Joe's desire to eat cabbage, but only after I had questioned Sally's desire to.",0 +That's a most kind answer that I ever heard.,0 +That the march should go ahead and that it should be cancelled have been argued by different people at different times.,1 +John criticized.,0 +But these tickets are terrible!,1 +This violin is tough to play sonatas on.,1 +Everybody around here who ever buys anything on credit talks in his sleep.,1 +Every student in Mary's class is working on negative polarity.,1 +He's the happiest that I believe that he's ever been.,1 +Italy borders France.,1 +"Because he's such a nice guy, what would John like?",1 +"Every student, who wears socks, is a swinger.",0 +"You get angrier, the more we eat, don't we.",0 +That John was elected surprised Frank.,1 +He could might go,0 +"Pat wanted to try to go to Berne, and Chris to go to Rome. 
to Rome.",1 +Among the guests were sitting my friend Louise.,0 +Peter is the old pigs.,0 +He kicked him,1 +I have wanted to know exactly what happened to Rosa Luxemburg for many years.,1 +The child and her mother clung.,0 +John seems know about the bananas.,0 +I made to settle the matter my objective.,0 +What the stone did to the wall was hit it.,0 +Whose tax did the nurse polish her trombone and the plumber compute?,0 +Dorothy needs that dress as a costume.,1 +Crystal breaks at the slightest touch.,1 +This truck spread fewer salt than that one.,0 +"Whenever Bill smokes, Susan hates him far more.",1 +Dante accused,0 +John is eager to find a new home.,1 +The ball rolled down the hill.,1 +I felt that I know you.,1 +Has Bill eaten his tuna?,1 +Physicists like yourself are a godsend.,1 +That you will marry any student is not certain.,1 +No boys will put a picture of Bill on your desk before tomorrow.,1 +He is proud of his son's passing the bar exam.,1 +They lent a bicycle to me.,1 +"$100,000 will build a house.",1 +Mary believes that Bill saw himself.,1 +Mary managed John to go abroad.,0 +John tagged Lewis with a regulation baseball on Tuesday.,1 +I forgot to return the book that I borrowed from the teacher.,1 +The child put the toy on the table.,1 +This is the book that we had read.,1 +Anne is curious as to why her father sent her a telegram to America to return home at once.,1 +I went to the store and Nike bought some whisky.,1 +That girl was given my binoculars by him.,1 +Vote for you!,0 +The close brush with the law put the fear of god in him.,1 +John placed the flute the violin on the table.,0 +John is believed to be certain by everyone that Fred is crazy.,0 +whether she left is most unclear.,1 +He skated Penny around the rink.,1 +The cat trotted in.,1 +Joe is taller than I think Mary.,0 +"Mary asked me if, in St. 
Louis, John could rent a house cheap.",1 +Where did the policeman meet several young students?,1 +How did you eat the cake?,1 +Dogs chase.,0 +The old cart banged against the new cart.,1 +This is my favorite.,1 +Each of our rabbit and the neighbor's cat likes the other.,1 +It has rained every day for the last week.,1 +I am not certain about whether he will go or not.,1 +They preferred them arrested.,1 +Tourists admire paintings.,1 +I don't know the boy the flowers who Mary gave to.,0 +I saw the student of physics with long hair.,1 +"Although Mag doesn't eggplants, Sally eats cabbage.",0 +We persuaded them to examine them.,0 +The Dodgers beat the Red Sox and the Dodgers were beaten by the Giants.,1 +John promised Mary to shave herself.,0 +It was believed to be illegal by them to do that.,0 +John was carefully studying Russian.,1 +I have to try to finish grading some papers.,1 +News of Persephone and Demeter reach the great gods and goddesses of Olympus.,1 +I know a man who saw Bill and liked Mary.,1 +I've kicked more of a man than you have.,0 +Who ate what?,1 +So fast did he run that nobody could catch him.,1 +Leslie was in the flood zone.,1 +Volunteers praise easily.,0 +"As a statesman, scarcely could he do anything worth mentioning.",1 +Michelle kept the papers behind the desk.,1 +A variety of styles have been in vogue for the last year.,0 +We consider the men all fools.,1 +"Although Bob may not be a nut, many people have claimed and I think so too.",0 +Dan walked to New Mexico in the rain last year.,1 +Calvin will have been eating.,1 +We keep those censored copies of the book hidden to protect the sensibilities of the prudish.,1 +"It is important for the more you eat, the more careful you to be.",0 +There aren't many linguistics students here.,1 +There will be the hole in Jack's pocket.,0 +Susan doesn't eat enough her vegetables.,0 +Lady de Bourg tried to persuade Elizabeth to renounce Mr D'Arcy.,1 +Sarah sang.,1 +Kim is easy to please.,1 +These plants may grow as high as 6 feet.,1 +Some gifts get used a dozen or so times a year.,1 +I sent him that.,1 +John knows.,1 +Mary is shorter than five feet.,1 +Carla slid the book.,1 +John fed the baby rice.,1 +We rich have impeccable taste.,1 +The tree lost some branches.,1 +A soldier should be prepared to die for her country.,1 +The truck rumbled into the driveway.,1 +The dog chased the cat for days.,1 +I inquired could we leave early.,0 +Quickly walks went to the store.,0 diff --git a/tests/data/automl/data/cifar10_subset/cat/0001.png b/tests/data/automl/data/cifar10_subset/cat/0001.png new file mode 100644 index 0000000000000000000000000000000000000000..2f12845a3c3b9bf1ad55c8567b903ffd8b00f3f0 GIT binary patch literal 2105 zcmV-92*&q`P)o``=I*t{OrGX z@17hTMFO&%=0(5L?e==zUT-iQ^m@ahJ?M0LdArDpBFkGQ$pugVlmGxAs)&#w9<1-I zk2Y`bzA+fCIkd=$s#aC` z@=m)yTJH~cdi_DW*X{JWdAnoM!Z5Qj_=i7!c=+hB(`wz`-A(hNs$G&Kv-wMa4715+gk|ajZCJCo$(aL+H&CYOlFc^+T zBOY`*Z@=@S^Rv_I>x*uumu7ilw3SdwX> z`^hJN?d0j5pZ=mYg^SQ#-}0P~$K!&fFg9gIP8C2D6hT$M%~f^2Qof#CKKSA@8)WOP zH(ZdS-71Q01quSIa!v-ktH1v8-u;jM)M~Z1*S9zK^2GAb?!3#t{>{^~)9K`DKD%Db zCd=irEKBdbf+AAFd17r*6x{9gj#rDT7Z;`Y&e{ed=2yG5((>WUD)y8rofHVGk06a-LJ@BBiSMFFYJ^H#Sv z*vJx_key9e@4dIbxw+ZybU7cgOXXn<{_>}^3)?#ae_uOCGxHx>s z##H4B61I$a%8OMUycbSpldD)Qd;Pu>t?GKQSg_$mRd+jG@pARzV!50@I(Yc6yPq8$ zJz6dn$ET02WlRytnt}!QFqnnoi4cz5C{ELL?$YP!-V-Tx{jpY(Aea>zwG` z{re{`r=8xQRkRYc2M1q|CzmHDhvUhss8K}$A;zczDyXVN#D&SW_ulk*8fsrVoy;n) zX!8uC#HtQt)ufzSu4AqA;PCL(y)TNqoiM9nS*}E*su*Jd43MfS83u!(s;U6U)qHhp 
zXRqBKmKWvqc)DEGAp{_l=p_aRI-bpv#3Er?yG32qz78Q8AVp}36%kP&Q~-!k002M% zKvfZe>#AhKgYF$39OHX&be~6a^5@VnT;{zeFz~&k>;uh04gA=Dk7=~(hc-$ zf~E@o{O-q3pB!7(v$7J^ z*<=#EFY?S zXJc<~H_zKwm)8<~LB2&{Z)Yn>^U}o-<-5Fy@Y@EwCWs=S}__H{lTU!x>a}2mpA0Z>{LGZL5F! z?CRBgwpzN&+t*vSwl6PV9G#p!eQ^nB6pf1hpJ7y$2o)4W8Y0oCN&pIg(EN+QtzL`L z{wk6fVodVX%$L)pUmcwsUtiCHXk&O)1-PmDh6z;>m{34fL;w-V#}HKrfq(@RK)JWJ z_2l{d^!a5(bY-YYH(6F_x%N>Mm{}raUc;Qh=E9;6e~lRGcy}R6TA;m z06c&ggIq1;=^cjau_BEq_guD(H6RU(sE$`ju%(jqU6 zAwxt44Ka}gAO!^i08&u~Rc2x)Ycp$7094f&qjSL!ce;hBLofgaDI#h!O@v4{3%hwk zW2_-VWD%JNkx+p^k&(z4lO#6FEE|h}A`)X%k;3NIV2BY(3;_`$5jI1lpf_zdga}AL jhR6^aYU(ivN(l5n3z_~@Y*KQ@00000NkvXXu0mjfG8*`j literal 0 HcmV?d00001 diff --git a/tests/data/automl/data/cifar10_subset/cat/0002.png b/tests/data/automl/data/cifar10_subset/cat/0002.png new file mode 100644 index 0000000000000000000000000000000000000000..9ec18f38ab51fb008df4612110b544ac30b10818 GIT binary patch literal 2193 zcmV;C2yXX@P)k&g9w|MXekKpW+BL8u^O_+ z*f2#$nCDWNQfi263Q|fTLWw-*xl}-8WI{7DB{Bd&bm5??0LaWD;toV$TC(O$;6+u{ za#7Qw^O&cD&XZSdVhobzNgXgXEdWhZRyC`R$nL1>W{AKNA|kpwGrI!-NPjqhpO%sl zfQbnKOp*QeN!R!FX1AZWy~VB?Iq-1Ere&UIGxJ)jDif&I5M-Wb0H|8P3EaTR%xX0m z$D!mowXsPlFbf)adHepIzyG_xy?*oRU;p`E`mv08=bkEvS#_PwTvZWKf*4>aQz<}1 zW=7~DQfpOJRdq0##_6!%BR3@jn7dW?^6 zsul!dmLMDm8S*^G7*rJjs#YQZ0A?{$*&MdN{-1Lp1Zo`@sas8%YE?U2J^%dUaQOZX zdw7da*T)yz_9iM@9+g4mO(u5fU3zkZB%PDK#-kY;O)WR#U@o zpFaKe%e!x$KcjI63XdP$`I`^^{NMiT^WS~-Mq-+#F&Qk6j&E*m&{0xT)xCPrLP&?I zUey823=rinKK!5+S}wb+^`qZ>vAIz}n0MQcKl=45Q5)8?&G#PNYu0%8-PLDLFBn@v zQY$d^p=LsIHLBI!$cp2SXi|)!6kv8~gqna>N3E3wm?(mucI|L=wK!gl0V=?uAFpqA zQ!f2D;z)DOWuA81tuPaXJkO5qXn;Tn%;b(g%3P}octs>4NsUN|UsajQ1s$QA65gFY8tk!B)1OOb!ktKGk)d`%x zj*P1=-fK=nowQd#+%|Bvo1>VA~;)vlae7E#*Exe zW@yNawGv?vM+Qna+xhG7zIk!JW26*WggDW@AI8ZZK6=27?Y!R3%f`K^KHub%ZgI5i zZgre!Hm72u5MI9CTA17`P#sws;P~*);&>(0ucpH7 z;+6vmd1^9HLB^BY$8-WwA~7dwj+d}eJUe@E|Iwq>YB3Ero68|f=KyWnWmPTZ@x63; z-93FdPsOylxXIIJ=d(h)TpM~ysZFtq<1k!IV{Mj;X4zV)^E8e*ua=9G$M??e|0Fcc z_2x1%xl^}T5|OG7eLqh_C-n2j_g`bJlLT zk{AuJ*4nmR6YD%pCvmOAwA);R)94OvfUw)`*GI>~-Yr&bn-ColmWwW+N6VA>oul^r z@=z)OfSjD(Cd5PEhY(^+ecvbKCdH`m)?6}DGJ_x? 
zrF7`Gi0tlQj^NCJ*csi*G*1pdvbnxobjv6W8)>NyP>N1t-P8%0fNb4%k<)(YjeNJ= z-g@V!u~`BzxDVqrO|#DP)c1Kf%)?=t$NdPz6oeVo07Q$TqiUf*wbm44Ie4X@b?;-( zAuev+u4}&;vLc1|slmn7?iY9O#P!`bo1w#om&-r-vp;`WvPL}JGQorw?z8GI!?<+DAp`pY;CC>%`1X_waRp&Rm{=4h)2M-_r>}L;7 z&)(W>x9>cDyq~kIy4ZBzy|_5+ZUFMC;WGKqk0RBU;bt?PU+q5q@`tN^1{4AWB-w8_ zi`A(a5VIBo@LY^2{_gj$2Mh1C_dj{|`S;Ji`>VffpFDX|JpAh4|NGP5eiDS+hL2Wl zzr8NGFo%@dob`v-o6GH3FqCRc1`cxh`rLtIcLYFyXw}TAV%m)M@fR=t_lp-qB#iIv z>S^!)@t?o??Ek(j*%;mGfOJzq!4a`e(-wiW)(Xj~I?mvZ0QlhtfBNc8e|G;}j)_>z zOiK=ou7wCHxCj%#(b2j|>GEoe7;AO{tU3c#{V|S!c`if*I@emV)jUlCjsOCttJTtK zh3W>Ts!fP)X6S&#;Edt`n_VVUmWagUfKCWTButR!?0}I899VEN16L$M2STDb=cet< z0o+S2gh)i_9z!H#CSeo+4$=@ubE~G+)ybUPz`=|O10sWi)gPZ;ScnLai0JJJkP)?7ok)3N_-|Ox^eI9%7@3k28oz4IY7fqp^9;<5<)r-r zix4ibS~Mq@#sUDSLNNAdr;MpV#;Kw|P<3JrjSDhL4>oTLgcRjBDTB2x^`Z)|9g$x^|bFek28SO`J5 z*rMz3G8CDC)&K#78v4M&1t1H#z2+m5rp_IM(Tqk89dgt((ZB~=-QKA* zUY#j+`g+H4;+(($8U#RM0{{W;03jiZs=z%u6OC>T8fc`60liaV1Bd18({Z*<7sBNl zXG^=;>*k2ViS-0JVNio8t97cpox{~axuGQ>0}U`36lkyj3owT}x&gzggxJ8mSpL)7 zZ#}mdYg1Hx(6?t7?R>LmJz`gIS2-1yhQ6e^?GHE-Q-v~E0T2L1!X-+e5y8YjbFVd@ z#$S)~uT=Nrahcm{uuZ8bP3vKiE_bJwU$^4|@r-uI{G}b`&kPC*fxKW=UnnfZGJB$ zd)pIAVtis6oOw;dcJybofO{=s$_@}d01b0)yHkI69Llt>Asmn6^-Z1U&O?deg9%Hpb9#SqMl(f3{aggV_Fe$whAH+M9iK02 zO_@fb)v7d3{m>Jr(wJI^ts=XHV4)JF?|!wc!)f|KOK%e!IhX-so>@NpQ5{n#6@xeB z;6Cp4A1|)|rNcjuH~)6}s?;)g7u~zqsWtO1tGc1@qBC^SA>=gGT>CM-GaY}_@DW#o zTC+7ZyY+qbPA~89`E8r_b3S(K_4hB%z88i+wbjo?{eAoLCXF#!=V=l9?Zd0Xbee9j zJ;piBx#Yl5N{hkDvir1^IG_Fy)4Lep5z0=&hYsI(h^G&^Kk|zg*j>-N&)oTVv;FG_ zZ=UV?ze~?=>YPehEmz%Yg>jn3Id)xZ*v~1?(`LCG7K@-}T>h8JFH?HP1>I52Q0u*e z)5*(i{O%vbKl;Fb|GQW`m|neNy8er&-~7qh(?0eF2HJ7A+wE?rl1gst)ymz$F!Wu7 z5h3(l8kV0BUl6n)!BCs!f;1tG%o!pVgFSs>-+Mp4_r5d?oqo7Dd&UkSi0BW8_ojM< z&}xez1Y)@yViy~L5CE$deogxgGJ7IMM{pRxKtL@65)C2p{LHpnKR+*S7jwO$ebo(@ zm2Z0Zvx{fH*}eLQyxSkfG^ZG&n&7SwQm*f+`*we|&9#}Z>K*PNfpG8fupp20*o?5i zM^~Qp@^{a_F8PCt%a3gH`FMQu(e{s*+uvUQ-#;pqRATVfRE2&_->==rSC76ktcK%B z$2*4XXiXMKfdTI5dU_O~S}kYff@*{GVLbbjvnL;J-}vnF=cioTYV-KC@|3Dssdb#@ z#}htUUjFdmlc(p~;8fY_W`WyRj?mHWJ=N5A^bWPDB+CT?%+&cepFI7m&1Sn=_M7vr zmY?5#`JXRe+?I5?THL9mnm>r^zkKk{`)|Ir4pfY26(r9eAkaaV1)3o`5Hux~gi;_$ zf(6m~v*lH+NC|nrUJQO(ZtnK)`r?;)8e;$PThIRd;qSe(IcrnWX=mM~g0p0u_UH#k zt1yZV0VyG8w5n+miJ>Qp=0X}lrm-HrzKH9;S)A2}=l{FAF4%qR@zam59(Rk-GAzw% z1tBsgT2S6I!2t@S5k{q|!-=`tdVxN+)*v!ELtt&{v*lUGuY#}q?EJlpH~!Y2`Fd?_ z*3`@gnXozsNlB}7!f?3u`{UE58Dqk6QXnqYj2!@rKnRSTLXa93J=R-}d*$3LhOYOn zYoS%T!B|V{H(YEm9%P_7wr^lqIA|1cy~kn3)tPNqB0)ni27>_v@5eVty(#uA$(owb zO0|mML4afy0`7BjgPj!ab<9v?vBRS^)&qEtD*!?hBuHYTfKn(i5d8wdc{+vTEkjUS z&>}(2GlFBWLaxdQg#cZpsWc1`Lll742qT)*UY|J{3~G{q=m>_;fq^hhQ=1bnJM}$O zWdQ{HL92rCK$%D{H+yqh=YY_6u-C;*?f$!xJ3*i!b{KlJrrOYixa8F{)CwcIo&gw7 zN;4K~dIXHBIvxG)u-~kPb{x4`(tzL!0TNoHn(p`doy zhqw`OhP^ir{SiwG|K*>*$+a#ecU`}~xw^jIuU8vpiCrftM2r&f+0*xTo^*M0l?G8n zLuLw09Ebu^Kx71~x}1)!7K=T^E+Vo!ORE*!!5!R8&CL*ji3un$p<@G{m0KJxqun@s zvmdN21*8}(xiqb@TZu4{rxflQ z%TnfP)}|s+LPBD5`r+5-fB*K!W0N#&gT&Ra-fVVrt{*<#pXX&5y1-uY3~t-?u-o*j zK8UIkF(9}*5eo-IQgu~r4ZA>A%frL*vE=D8{(OJ_?(O|p=^y^?ujk8{8yqh6)8m;W z4twmnbUuv;*7bd4WNh9{z@gO!2#x>%08XI7EDSLqi@g8v`MdY`>n}uHS<)f%zy8~Q z{l|ZPyIyx3yY0>Gt3UfIVNrlN=h|9~LD+@Sw5h5viE2Y8L_|U+jnZNg2`w^>OaAtU zhr2`OtFvX@t%v2si@W=W?;msb{Jh?FfBV&+?{9CdREBYy7Lve3X0uxKY+`ZPa+Vn4W*XW84_rLzK z5&He_zvvmzvDO9%U5W?@0Kx*KSaNg3B`?#wxMPofCHd)LIV@O!wAs~e>oyL(sLJdi zVXAd;ZHFf~ACDV>>6>*=7z2ZQm!dGbxw{LYBZ9fRL#ef@vH~kTo#&5F<0ryJZ2MLeAK79Z7|GxhE%hR#OPF}qD@^r%c z&+m`~DSrRc=S`okdWnJ@pgI63AV^-cIhIm^tsKUoSBZ8%FCU+#VqLA5-PPviX7A*w z>tEfz{CszJn2XEx_SzmFpO2^W)$ZC1KHNQb!f*Dg6aYC?6+{F?anq$}u13yFUUCMC 
z_&9ngNRdgx)&Ayse_cv`^XfN$`MbY)``!P3jesF8rJT;^^?Ge?rPP{us-B&zAsMm( z0uZ4P10n#)Qc-~&lR`J=SrvzV6=ERb{mqNZG+W##)8GH&KaIzSAK$<0SN-{X>2HT| zoXS!#q)TZTC?GQ-aWzv#5&#HEP@!Zr@Sek>m>V#MVYNzK=YYAk*!8jBrmmaI=j`r({KI#L^Le}3t~b{);p5Zuc^nZj#zZXUYFhKpFX#I z*{p}0=M;!35I9&jUoK5G7x?AleeYxPDc1!6h)AYl2<_8THqZ=mSn6cLo;M+Q%a_xV zr_h!ucKr-I&&!i%#(fm>R!Y{J%`PRLr}0Q0M1osu<#aqP^Bhwbg6L9FI7aRyY2=V3B^D+| zBn}uuG>lV)+7$C5XpE?8!_bKcOIWW~YF4MI)mD}TTAk-M1lbOo2yq> z0uY<3c@qX;f)EHXnDHfN1VjK;ol8qy->o(w#uy`EXqrtSg-&`tJU`A$o#w*KS}R?~ zj-x5`{qW|CZ4yMJTAC0jJ2E;)Y8t>Z6EO=SyMwz=^DKX&R9Xl@nAKd60GT9&KNajY z&$ChMSMk@+$D3Wc8khmBHX#53?GrQst)vFV0K{q(0KAr2YY3ucW};zOB?+zMQgUl8 zihyt+l#pfwE7}ZyzMrHU-|YK@8qmZX)m$AfbD1(%wPudo#Bm73;BJa$XpX?sc#2)$ z)JiFlF~#8Mb(xb4FLoE&T}c$(^doBBR5s;I-Fep&pm)a=m^mr@P#5Td)q5Nd5? zmN*boF70wURy9=}wtWC}fo06gWV>}Y^y&F-|Vn&#TFdd^wMI+SUiR_+p_w+imoT3%FJYYyg6n~k|ln&+I?K^%|>-Q7&Z9g6x= z-6=?j(266Pd1-nWXS8~??n*UlO8{gd1hdv!t%X8LU3GM%^L&}AALhlprR~~O!GVa7 a;Qs;Gn%gRribT!;0000^Y4 zz%tkMN0ZTP^!U-e`}^Zpe|q`$?R$&xbUcUy7dnhGuPU}xCqIAwWipv=i;ZYDKzA5f zf&gTMsNPUY5rhar&N&3oTI+>jJRT#AADy2&jJY0t{_GsI+kskl z`}9{&B5Ys1{J^P-O~p~^z+M$qq}!TX=GKav~Gkp7Q)~E z^S_KnLqdq|``h_^I2cZ+lcsI+EN`k}5IU2@n+}rO55L%5?jf^UE{nEvJ(h$X#E$I2 z?R*)IX7T=UV_+1=x~q{@rZ=*O*NY9gz5GZWNI3--MSfUUrRTEQ-ZZ;a~rNG5qFlzj^-rsopJ$EORhGgi5IZhSc+j<6%O)@xk2(=NCW!D0W*ypu=zw zd1}{AW~1e1Q`a3N#G$aSTlDbJ;banV7h!C>wy{Q^o!)Pnx~Qv%j~|DA*!M6VOf+(e zov1opi9VVgojiLnIsF{afKtkX05b*%L5#*pG@VYR>=8l;LC+0H;PKx)eXv>=H@6#% zknF^F-~H#&;q=M*v=aw+PCnagvhj3cFsiF!n50VTbpL34=hPC~Yt3EmjYmTEyRwFi zaX$!xNGsJV$Mir*Kvs+O(ed7clLO+wx>nt`%ZmK)U<^fdb$zvWa2PY+Vmsd7pV(Ov zg+-o2P8oA_*Y!oVKE(%#yzVBnknd@R;tkw#^R%l`Ye)Hz?X1uQk1%`B%-9KF3{ud>67#1;LafCYEvQtJWhKpjL|~& z8une&x|}iujPClrBg}CKfe<5#l?8R(cwXEaP_q^;57DHRaLeO}1WV zf$JbGwu>e1d)?K@ww$9h9S%p6WH1bZ2teF+T{1|+B;D3+UF{h6v@rl8YvDStI*W@| zL*g=qP;c~L6b2!i&Gx-u`u)pyMgh~bZBui`q*M^$x@s`Qqa<|@F%~mFM3h0o{5XY# z=)ToT5Xz(lO{a^Z38Jy%Mu6~Q*ONSJ7P&?_aKwvCt3V!EjXX znmA2QPEY^x>Qz}))7f5ctq_t>N*!3NKV7_kjj;)nkWebRPKgd<6a+D{q-$Da0qEf` zA3jyhuHLSZ=|-vhcfb1zg;6w}O~=y-Wh~od4<0?dcXHnkgCLAtFM!yw79vc!6JY2@ z@rRx=h=!Ecm&%qWU+;u_-VS6nI@tB*H8TBL8>+QBc zz`ESk+Au$4T_a4tDXSvS5^o5NW^QzJ=b+e1!d?9J*Dr*Uv&n2UPJj5p^(bb3M4+aG z^zBww7;s+J#gDJ9CZoY*k{V+i4B|A=N&<`=KOmxMd$`$lS2rbuD2_Smb_X7ES%~f* zALgt1-gJ`Z8}8y4FTT>+wz3b>F=f0LJE$8*EP>$q>Si%t>b40S;4|6^Z839rh-lbK z^64tO{J30Z?d_^vLTi1-Q>MeRn_EKUT!z1r-!F!hh0-)OzYN$BSPN2 zc^wDt%OC!`+=(!W|M>F9X*$DSK1xQby~%Jm8gH`o&GofGWV2e2lXzEGeb-n>!q~rB zWUFi)1SB4K4tIv*G>AgiqnjM8S7n-vkp;vU#aw=V{%F1`ilT~wc(Ke@t6Yk1xtKfS z!EmzYb=}90AFS!Mk{7?cv$iMH>pp$zWX~DiwYAdH7!^m<^rdWo&(V(Aa=W5*LZUQQ zqM6M0{_3}1zx(;(fByKFo2#F$=W8XD+LhI29i<~b@Rd@t*%$z$`rd%PstV3onsOoa zW~+@xzz_n7P|7GnDB7FMs+~4HAisY08`*Y?#j-B*FP=S&2N9t}Huc-r9}n;PO1GWR z+-25i*TazXJcuAPR2yKC>XkwLAPxv(2y2Hpbz{rAR6-PaZGgYZOR`;FZu5;4UEW@q z3?$*;cTbbHA{aby!F5{90;-Q=r9 z{^tGWdL^^0CKQ7~>~6O^Ahv2#gD%&*GNKr>LDM#2a6la=&x&`i-#-81u~a=mKp8Ee z?WOFxS}I{_>No&dqqTH&>%(1JX4|q}W?dX6epCA!@pjrRArmpD&8|Lp$( Xxyg(8_3XHD00000NkvXXu0mjfucYl{ literal 0 HcmV?d00001 diff --git a/tests/data/automl/data/cifar10_subset/dog/0001.png b/tests/data/automl/data/cifar10_subset/dog/0001.png new file mode 100644 index 0000000000000000000000000000000000000000..cf8482e2595df960969f58587af88c07629cc149 GIT binary patch literal 2457 zcmV;K31;?*P)GedKbW(o*_H9)+1R@B4r<4-TJf;vrW6WqYIy*btZdY1$ z<8l9TdJ%B8UK-?*q?Zk{lrkP)T{mUbN~ImRo!!sw?#jAC5T)XH@#M+bc+z8$?G&-V zJMV)yqCwywvesgZ<2YuF<^5b6874yHS-^QC6^n%sN!IHrrJeKpa<$s*R-4`IVPU*W z`Ju6=&i z6S05$`t@I4z1i*$9uXcT5o3{vC!^7Bxh{^UH}7u8JwXW}2qPYbm{JFUHU`s!?65A?0E)+?|fl z>p?#wX_k2e0fY#_FbtjbWmWQkGseoYu3Eirkc!ad7cUGKMR_E^e*5YNv%9PJdx%IB 
zC$UHkfP_HAwh)4q0$B-NGN4A@BkKdVUSg}c{m)P1+8WA=nLNe z2zfwSQyqSko2_t=lA!brh7bX4oIAg~7zBb+y56ljEiW!6XFYDUB-&casP`TtgfJqM zXB^kEo88~l()D=YjNw@hVnG~?WL5yAjU$}LgVB8Rcsdn~#@8=C<2)$JGRrb;I_GVk z2*!Z43N=kbDIJf;1At&*khuKx3bFS7|BLNX%j5C5?X*eq zi9_)Dr=NWN_1BBL8G*<PH8E(;8iwq|m6Mk&3&yQ@w`Q+J$&D2jwOX16<3b0YF`#suJke4TjEBP^ zAvo9__f^w4kDJbyjRp`Ir#oE@MZe>sER8vB)au<~J`ePXJ-vZZ2!Zz);Z!g^W*lY% z(HR*Au9t)aq0960v&rCcIvHdMY1^v~e-#sjj#}gsM0^hWy5)o!R61>{pzwesu zdb@mhldV4g$*=w;o+c#EdH)Q_F6AJ}`@N9rrh0g|d-?LavT2{)+#oO4iv_Vp8>6$*Q%(qoVDNAWw+xj%W->cGYBFG9A`xI(o<3Pl6bXQjYos$&pw)`_tWX+ z^z!=Eci;clAOG8HWqi|U6e|N8RbV;z@|QJ^x3l&3|Ne?Vb#eW|7E6NNSD!uy)NhMh z;+<=o#(3+2KkZL>KDfEL+3a={laHTX#}GrzbgR>}|D-=&Zyp-mwutp95gp5=f|Q zYK(mtg@fUMQtF(egjQww_J<$xem2Z<3lzm@dO6)4Ps`mZ?I)k#j7q1MiyvT`(1=y_ zX}6xQ-u-R)?&U23!+@DIw4UgwS|0PSbS0TyGCMWz>^l8p>UBJm#heECMwFsu4ZZEljQDxZh`# zC|&c+d8iO!G){y_QyTEDlK`NL^9$##(w$KC#39>u^^qZTbumqo5Wt3VWnE2-cX#uLB#i>c?&pgnN(O`eV9>`H=tdc- z4*P@0@cXxKzkl;L>(u!C$>er_)$HjFqXZoW7Mmk#lW006Wx5II6%tL1vVV2nL|_LM?&J~^Wh0B5_pf=V+F zfOYk$^hRTho4P=R0RW;jaSm`EN>eMpS#yyy;ScrV>~v*Bqp*1ks_oZ1tcEel^0(41N;;IMWjf1OY&qTQSv~D z!EQ4%nfCN_w|l9o?pp79?%DZ1_?zE8lPD!d?W0axkR-~=x^dKXqKr7!=|@1CMOAOq zdNk}h4>nKv{c@A`NiU5O5O^Cvh^e70LWo9fZAAO1gDnj>i-RuPC4=BFj?QE`kG!>+ z5Ki=uhRR8Zo$&=qAWy6PouflN+?2T}NXFx%?w-zz@t`l-rfoK@tH+~#N;stOU~ne% z?$z{M*#Z~}Q2 z&5mD8W;4UZXuX}SgG)_=Cv4+t6LAZVg2 z3anWJiYX$)wATQTx8=$#1{oO-nQ-$FKMs=M(;t1*O;3!F|MvMmy}$nI(+@sKyZk#Z zPs8L_SHHUmkPZBR9H-9!BlSu;nC!=sk3M|KgVghQvAw=pU0yw0Is)s)UCpm0@hK}d z6*ne(`SITl29q@HA5V_MBuK)i$<=r?-HjDMi%a=d@jkt_g=SY zclg=&e>7s3Rw%%vG`S7?6>&$sQ>*m+{@c8e-Cng>&oNYS(qoKep_UfNkpkA|5xmIXoUs;beOPO#hMmzUQ(n|%2E zdl%n~$?Nws=Mg8(k^oUwyN&4#_dP&C`!E`2R^=+6Vrxx!AX(D7WLxU@*VpIg@9IXK z9vn<(_fM<#@c5|L?@8<5T)h7D<-{6ze*SgZ@8e{bloFb1<&2`7V}j#>rUKN(rk~QH z)^hz&EoQ7Ukcq0wVx{!cM%>&y6+#~zO{f)bE-n=u+)V#iZCC&Jv!A`aetf#W{q&>p zX1@67lOO!UXTK(m5VC2!Gcn-`MZ$+$g$(Vk+I)8ZZhe1Q=5;#BP^T5(ant&|62RI{ z!gaY3Mjk!?UJJAT`uyLY?r%QZc zwoq2Co-Oa*&%V6`B-u+fP+R58qLE5@N&;Q4mb3S}!|`N1c{eS*^(+kYvOYf8`N>az zzPkP*pTBv$`{wxB3xdOxI7cwTG?E%@^P1omnzmgnW)BZVDKR5tb`MjHdFC86%3^2R zw$awESJSYO`mq7qWvUg_vZ=O*lRi~@NKac?FJ`wyd&}Z3p%I{5S?{)M00H!_m`#_D ztF{%8T4f)Sp${@+L8OeMh?uk^DlpYdaxfd@a);i+4JL@ z+e@PA7BY%SY$5g5tETqA*tXiPZbe-h50Zp=;ANd_hiF6yqOy(mnYW;q(YAP8fr8|sns##jJ6l&vNR1D29NS=}nxLTeH8 zkkejoI2s=WjH|Y#JnoGc!`PYD>6T#Ryh4OVLFOHxJciH})rv?ZntGjt5l4|Rzfu$^ss6C333;;goP1g<&lT z!69cX2ssNQ#^XFMNLf6NcTc3&P1AVmq;?Jv?~F2Q8s~8*(?@ELLJP z?57DU^Hr;<^;jFBm8=jGeJ1>Ry#gD3&sVq}9TNh3BG zjC#GH^=i9bI6Wp>Tc@j~bByjo!mGS2g+d5eqbbI~r^XrwNKp&mbf;{>b%QX&s8Ow@ zj7MS8HWjBh2_tBorC#oOK?8z>xrlC8(_fanuQdm?B8xm^#M;eVuvNJi|=FR=wAOc|=20WkuxHwLo z2i_Y(0OyD>5b_|20s)nFRw<>mMi3PB_N!Na{==VMUEfbzB>>@r(Xgy$|M&HOzkWZT zO_7lO{K+%M_V!OuZG>^q8Hf@0P>#o47Li&B2m!`~di+$o+!%zMv9_-3>Fn;0Uwn0Y z_aqd!pDsnYgho{5s%aL(Lw2w?@IbIZBD7gZ&_l>h($ literal 0 HcmV?d00001 diff --git a/tests/data/automl/data/cifar10_subset/dog/0003.png b/tests/data/automl/data/cifar10_subset/dog/0003.png new file mode 100644 index 0000000000000000000000000000000000000000..6f4416fd482b5514e7b48f6b0bcb67e13604db8b GIT binary patch literal 2416 zcmV-$36J)PP)~z|hk(m*Z5gD%=aqm5ky?fH@BNl#leEj8?|NPajet=5y*>4QU zI#z%G%fI;BU;G@B?z_d}AOH2;vo~dDP(OaK?b`g~r(Z0}-M{?X*Jiu^hrfIJXP-X; zOg0zee}BI4caBt~ptv>cjz^=oTu**@u24yqr$h>vApnYMTpW!jfj2B0SJzVcjz~Y$ zSuH9yJHARgy3=V1aGF>Uk>Drm>igq)Q?Wo`=HP>OK7{C;CqyZVtlchF<<LsosOc_39ywAPOr+h=kCS^?F|tC0AOKe z4#881z9@37z-wI`J?wNk%I{Y3>c`P#cRlyj`=34G;)dG0|0GR@>x$n@n&ozMyWgy` zU1B|8Ac(wkGU;`kCqv+vE8h7NObI z_Yd;vn-}w|O_Lm6UCkeU@~B8VZ_kd;XA8{tH=DSte5Xi52tEWQ!ci#3NScx)&09s? 
z)UZ(?R7*a3cmCnc;{EH_q3YYP+eV2#olhszi?hXiJ!=+?1FJu(O84e$SAk7TR_(Cd zy7rC42k*V35Pb;by(c1L$k-&J0zfiErbrVt9Q3Et39529IQ)adN5lQ@>hg6`bUTet z+v(0H#?WG2y*=MVN>N+qbh)jS>L4maB@mf~qOVa~Ln%Z~T3thEltx@d45gRFQ)|`?NXe~xKiO?}q0Tc)-@LhgGj<3m2MRH4O9ucPodf}9=0F@8R7Bcva0)~r zAjBB_WVK5Y)s#Eb3Fym%+ZRfkYC3J%L?F#Fy|;I3SBI~@IiHm#$ueP#;+r}I*MM>a zqCl7fMP{Xd48~CTYRL^d0ADZ1HZGri_4RPrzm@SpyC`~hOlDPxFJ7HK z|8C0CLym}nV`ypz-f5~;tbt>UQMtw;iZM!|3-z^xmH3rab+_GWYL_O)DlHCvczMF2 z^H#?uZJTzzm&S*(+L#m(QA9+70BF=)SI6HzUteA9b}L4p5Q8?9q{b$Q5K>DTu{7Mt zbWW6Wl?O7w!+TFUy_>CmYLdJWPzr*Fy`H^$2Ywtk0c~Wj5$at%y;#pL&roz_(xTTT zj!|MX)@W3TR+c2QR{C}|9ewdvpLR{w!?e{hy`u-F+b#Xh60~&W7&q0X7_>eZfYG&= zT);8FVz#-~S9#4*D@ZX(EpY@AAVfk*EaaKwNz8P%pWpc8>E6-){q1CVHNR^2RR3^q zGG4rV^>Ua{QWTOQ8)~J{DeBESXl+FNb{pnr^P;|v2_+y&N|dMBcC#vH#3 zolK}_X*@dKEaz@J@3xK73Hk`=Hd`9M+y39zzYSH~dvvSl4FGML7Oh^d*X}6f$YFbR zH64$Qf{#BOzIwmBDjOmdg(_|%FLHJyP|qEWj-Q9dYwc|g>;q}76RejwSrpgH#rf#+ zyNlh?&DL;m^L~~#nj)jboh#?%^pY!S6~kWdaL`$eH`k;(!dXYByU!?GU4h z+!`D_{+qwPGCg_oze?jT|ME}9uU>xhU(cTW^uykM>-^-j+HS5d=Y=*Jb+g`(mZJ~s zXHOpe>CYaJQJbUTvA;ZfvC<(b0NR{P6iYuMhz1b=y8raaou?0vdi!&;jI(bKyY6Ik zs>|@>|G(mt(fL8vY8si*JnS5l4lCVdIMA-6eBQfJE$0tE|8#MF+>A!3kg_NMY!o0! z5XcV>dLKVrY*xcow%*jc(~Hx|u^;aS{Z4wjyV$P$ro?R!AAHfAzkaV4m7589xq9*2 z>yIC&y*BF#2Cd=Y(fQfPD6KFCtq}-?Ac*WedT{5{KkV-9N$n(ra=D6|Ya2WWCuCc? z=<9Yzw^}Nrz^89c>P1OAmpSLJC!^!<)|)A4>Ff>f{^UceNX4K~DG`jQW5h)0?8mp? z|LQkwo$V&;<=N{zOLQw!gnnJkzCVp;i`+#eKpGWz4-f#fVo|69xalkr|KQ%ey>?%b zMq%>)_ZKN+3~I}MescET|Mca_fBa9TO}#z)qtA}kNm`CB527zqPm2aN%N_wNPO-ht_MiU?>8Xu-rJ z?7V~+TxAf7tY~#RwspN*Y&EN#3|95|v#*&|mgboP%CZ;(88YmG1{Tz!Oe?Xj9a}^m;`aHKGV3NM?S^xlo i2m=W7wrPBfTIhG1nBuI(11E_90000v)5KU0hu5jj=mQ1ftaRdf3SroqUT7sxUo*}w`RdEp1}+OOfAZHXO{_qJC^ zfYh{vXiy_J(1_NBzEjtnSvf^!#&F_HUif>15C8R>zO5S5SY@$kXf07%C>)d)D(e~x z#-T+tgenJK9QQ}VXYZUpJsBU5yIIS3lmih2kq|-%K}6>qB0&T|L`3|zZ-qoFYm743 zDAW?9vC?4Uh;|^IYn(F<3?bSCqQ1 zFZRacwAJFAQep{31atrZfdBE`gga@C20|&NV6@fF8gvS+L2%YOtF6_9Ng%5R7shL} z0|0QKjaG=tp$Pol;poZ5#l^+hc#vjkjQ^hihd=s3+jliMEj4HhUj;Oyw>#rg5cLA#T9KF9z3eoAyj zopKHgQUjC{qyz~;K(HD%TD)4hPj1V{hB+@Hj0ny=!Kz{d&SjC;vQ6An0T@T2p{#Kf zxu0~qN9U*K?_A&yzn2)Jgi|rmlyOCUY8A1{0a(Iw?fv#%fBc9u69~>CYNCL5yWLi+ zWsI(>dR8nWEzc?G+0wX;qE1^bEfy9Fpy9AyfmuU)xDzmwXV9l77C`4-d@Iey^Xz;nmf3z9}B2^Qe{X?skJPyczu5UjNBT*aklV5(s`Io>JfbklYacxA4U zLJ)e45zk{n(4r`1x%FDDey@A;@c8EHrc%~uSvK|2!2zdYxmZXI)8$Gwb-x|{*ZVL2 z_~Q2OzH-Rpj;0WLA~2F&&&q>$!zl9SXp5>SN-hLrj8Q_kpmkZ6n^oBB_xk?<<;^p3F`Tj7=JbS+1W3{8x*u?86)&&uU%a_lZtAOt$JdjI!C+-@NiBIbW&M^oKiE-mwG-1+)VoIs zVXm&l`7SGKeYM>E@duyEBG*Q1V+rS3!v6AU=b!%mYv<$c<8&5=p(l(g&AVsMgde=U zoqhQGn>RPh#bTXA-sSn2#iUf*$4FQTmV>i(ljp0t-i998>4jgoJbbuaeDLvIYdG$8 zJJ!z*4VV~`f`kVo= zD15Jy^?J4xL0mS}QPN4^rS0}$ zhyctJ#_51ICkL&={Zo#}tR;+7Kk$vS$e1>DqnDd%mhL2b7pEllYh^B<9zVZ4s`FVy z?b&{BJZ?Xf_5XbMyU%8879gRrH}24oV8j(C&I>4Sf%6f40Kbu(d4i&|&gjizNy7MO zZvbJ);!OCyR;JvT$@SHHCmmI8{xq3|ozd}d;G|uz7QgxJr%(2i;rPV!V@8OqNGNb-T^8zl28p zZo8=DCvP6ZpMR3<44$?7)-{|9&e+4lbU9yApY4u&ttg;EGJBsqO@^7a1Uc+~&zpMF>yXOL~Q_p{%8HqV#;^3VTxaoCmloKTy# z(|nOP+uAk|WI{D-;y@=^WNkB<-gI|&`UhtqA}Yf5a$==*tTiG@he!M9O5tOFIAlE1^VQ|?Ak8vu zu@(7Wd1r9B-$JXj<(yLlYZ&#@XlK-ZyqhdHHA5s-RYjbLC_^s*bopX&_2-w>W^I(} zCn+OURZBuWFNmbfSF_jO|IT0eX?K~glQ_zJRAo*Xs}(TbavUACUO-1-B;xq$Borx0000wTlSZnRI_de&2zV^*cOd1P0c1R?Pgp^h-6-^(iR-m?kNFs#?9zgsryzx?? 
zRM1zXwhU@jwNXn$l~6)W2o?#B?ZghA?r_dM!`^HBzW#prtsniZHP$)Xc5TWr#guYF zpfJW-D*zCYb&i=qHO3Hzq47=M_bJBN>0%g0Rqgt&ZCd~cV@N3?8!@ibQmYsez$%jK zdO46`1kOof7)79z%*=h?GBF`5NYnJzHq0yv$aZ66vuV7lmQqBts*p7`#bLEX>xoE2 z%n-);Y$mE~Y~x!-QcwW0#tA~s<#fKZ);QM~V*pe@ssbXaA|R@&bB>u;!!V{m%nE9q zD+(cmoU_SV(b_l+X&v&z69x6HZ`-c1O|2D?YOShGO1X*vKt(CF)KY6L1PW3`C8r#6 zs-V_dV;GP`Ii*BICdEW(CjBOLe02P99oBIz^VQgO?X;h6PC7ysEv1x@^LiYTO5lGq-?~P?U#E3v_cx!8G*PaWJmP3qrJU&?=Fq0x; zNGVGRrNo>|%9W{MB4Xcq5o!C5&>*mHe2j_7I#+YaDMzzdEMhDwfCS?h8aw^$U%YJn zrf_-V^`G3j@teMPrDV3&``#G8I9-$&ixCB|o#~riRYfGHx*k>`1Z!tC{v-|71p`zKhSRC5MO>8$7ky_@f1r6Xj>l+UM))}uh zr8tHN09mqiR-_na({@e;s?@dzR{e+X{oV25aC~%n^x$FVJ439H4jwx=IRDu7cW-r@ zX8&LpNI)fN$SFy!BC@eD+1crlra7fy9Ah!6fCxs2327WhW-MIGCKP zwPcJnjHudhx3e?x?eV?)_hSr~FI`lkyLWG&93Q^@%QvTe3vNP;v&Bg}L1HQ)6H?c< zj3$f{#-WtjbzQ$^l0NJI#bLY#Pb;ZvV^=JJ)Fzwxtk+Z*rx=AGrZ{^Co2 z^Y&X;fBB0yw>AzQd;E!=N47JCRH8R_b7Q-Ts49_>f-$9>!aNSEbztUNt0}4a##^?8 zlxsPE@Da!6+OOY!{`o6^{^?I&|LyhveC?;}VYvOlogDMY-G?O7`{vB~-4SAuR8bLG zYALnmq#}yQh@`nn6~o5#?F3XO{UpcYth;yj?sAq-j*m+z|9AE3tN-!ptN-=iw{Cu* zDfiYutd`yT_5GOg+4K9hbF<|M0E8F-5CE*Ph?;YfQmwPKRx^aPNLj2;R0Tmp87tcD zvwJ^z?ag0Zy?X!d-4p|%5~C7g@2hQ=Lm7Yj;jl>OE}RoE$iVDlibS*?M=9A@27_p^ zRxwqoNU2f?!5CJkv>Dxt(PQg{a@$2}ciEs$uqR*R+#Zs>Z|#Z^t{!u-rVeG2NQ9 zz9}W!wkatSu_~~FMFWt{&d$zib*d^cg!OtD)o=*zEs{`H@}{=v=L^OMt?*RG}I z!01w{ISmA0EUB8+@hQZjL*CflAkbI_1tmiwl?7|ATGgypYgHY_q2{u+warXntna<| z>;1i*PyYThue|uVH-CQh8(+V58iuMVkD(T^Orl7nIoJTH^Ua#4*M zCFGcDKKtn2&i-!D3(O*7REiV;#gwuv)LNTTLPOhoXMXSC!kOK3^ZBaQQjI%% zc7HZMO(_Et5>~u*^ZfWh-SH5F|(>7GPBLOs-hA8gQuUp zbm6Jl$-r#LPtNW?`p9D+IehEwEJ;hPC3RibcpHYb(K>6b1%*0-8e=E`%6Axh%Qm!4X$#vlCi4@S{y*Kj|M23aiw5M+UU=b^FTH&2+B@&vxK&C)L}oU|Fqx9o8^_G^ z<9XM2ookJvDuVCYloAnlZCh(?jeq>&<5!-1>QA5hWZU`W@nU&;baHsF>-uM(z4GE0 zzwp{?|MRVHeup$e$+;A3jdd=BxccxOv)w<~fAqp5#&YXhb2xv1XeRw;M_fwLDh&M8 zM=yQhFaLV)%$bLWht|5Tv0wY@myZsQmdlmz+8Z}+eD`~QAH&Kqs$|o!b*6wIfT-j+ z%w|WI{S&@*s@V{l)_1j56y4m|$RS$dn8T0%{a;Hhw{E@Pc<;J4#&~#mh(z8sHOu1{ z9+gzr!zmlJj{8ZA08QhG$r@`6w-Yu-OA)ibcfd?JX9Zvovdn-hmDVlY(tvSDS09ZuJY^KX)(VC@1?@Zsd z!*U@4Rmj9~Y+Kj2O^mC)Z&3jRa*ph5PC-gCq~ow^3~5a{ODVar9zb##ExTH>Xqk3X zL#C9HW9hoC?%F3mDIj7AP*597yF1&~7y~Luhzh8xrNj`aH6e_{ux5m5-!|UW zTE;N;eb@KBMGZqHB1H%xu;IMUh^U|f&@@eqiP;+T);g(R(in+!V{@8@A%)Sf4Iz-> zTw<1* W(DWd}{Fc@L0000*Aq5_o zkx>SliSgd+?Yq^zRdr6CJ>Q3a`j0=Ws&!2KFba1@>3Ep!_49}_h!aF%y{WVfgMN=V zqm=5pjwjq9CuGy>mnb-`G)0V3TI6|_WsGsDbif1=&`R5Y;pMBB)p9xH2#@%lD{wMt z{km}>;7D{@NM)Q~Qp_P(>nVXajz!m6uW>gb6d9!w;EmF2XVWypgkj1MA)+(Z7z!~Y zfKiO@j`+#6HyIC;G{YE6)pbS%jH4t$HUzJF!_=A(vBWtdhY$sd2vObu00i%p z43ym7R?fIu2;p>|lt?*>Ji&p+T((`xN(iCrv^3VEP&Hj8Oe1V(2#ou@V3c!8Xh@R8 zSmzNAL2DbFHa?i(bbzSVt#ed3Z=J7ObGdFA@nwU}%~Ec~oxRh^(aSQv+b(-$S$F2+<+4#U^5`%HVDiJNQh_oY zdS&ta2hVobE{M?EV@g2NkeIRKSuvSR=erJuMFP5t8KiNlq%q$1`-2dIcEp%2&GV~; zdi!y^sI9Q*ZbJY0AHF!+&;RY)cYl0yW4#6v>S`exRTK<4%sGoVO)<01-Z?%f)-_Cq zBf0JfhDI2vEnq5)BM3>O`-4)u@b2Pjy>kH~N=eDx1Zuuk56jv0kj$%FCG?BeuZi_0 zD^YKX2=sbs;X|iQ(7Tgy27|rvAWsNk)ox?jhB)Vu4-E308^r(m?scnL&SM@ylBEM` zzkM^e8dPsK9&n6GV>_W`D}uHb8;vkx1P4MHfK@8E0EjTm7(yuKEJYb0LI^R)$4PH_ zQ;X&1{*bE<@hpAx;E<@!QnuS|fNqN@8x0;jczEair$2rF`Z59)LSrouh*-qD_a0%4 z3CPZ1BN4J~yVHhP!W&WVM16P+!Pa;8@Ph*;I(&HcGz#5jYm)v5t?2P=ziQk2fA_E) zq~+l8W>IgpTNY&`#Kt@Cku#QH6hZ)pU2rypknkig(yN7hf3W}-qWM)fUp$e{y3PsiGMoaIN^LD8Sp ztDD>T%3#)Mg-FO^u4Tg_26eRpT{BD&ML?-&wu2(;XIyS}lY`OWlb`+NLce~apYIQ= zN>3j=B}izs=m)>MT_xJ=4~LbD*^_~+uK(vhe|vGhpbkA#v->~bqth(ZVtt$VftF%fi~jiDXm%o- zl9b3r9VXZqjjhAMSqp#za>ignfS@1^_GmatXQ}F7N_VVJpzkKujD?OTxcJ(bw(cVes331w4rMHVp)hT%Er1DXg129VT>1@Cg za&q_h>ED07|LDn||M=@W$Aj5tpGV^dkoH8^$Xz|ZI^WDMoDoUE_x9NN)r;x=p%qP# 
z3bz{8-N_HI$~KY=fC`Uxi_7^`zs~3s<(>||94fp-~Hw{dBQr?%%;7v$lhO^Yhs57Cug63iCny{ zE<}en(&7M}^i2q%=^pnH_@I$@*m@fb#(~G>7hirEGxqAwFE&?imFnuv`uxqSCM3*t zIna`4<3aCweq)Uhb#wmy6GAVWrsfcAJXzT?NqAZ$3BC4Ifu@+m#;yf#X(2TL!WF+fNLgUq9X|yYNqMPP#FTT0{{hNcnTp@;%!V6A<9vH0fI)pu{N*LKJIz@n&WY%4=njHGZR0-}}j zA@Vd9MpEiio@B*1N_*4EUcFfeZN|r^!pEQe_@|G*ct(J?jS@m)j4|@pmsg+O-!wud zc{(0WaL4ds1;o$C+-hN8IIY@{~5ehGvShy=3pWxZ)^-ERN;ci+zif~H}J z=6Tk!BxEt~^?b9F&2HWZXF69zBhHvr3V=iev6x~^(i{jN>ui)j*-zK&x(>b^j)c{J z`>^`{yYuO=oF1i!_`O~M7ax(Tcrms-F})lOhtRa5s>Eg$ZB~eS>O*kOA@2x6m|zMZ z9SIgOZ>(55DIBn>5!#{kPJVc^s#dLT-K;lzeD`oC5jeJIqG!r0Vwb#!`(r*Jb|`$c zPO2r9EjSyYK)m&x?4l^5Kq>B;HoPrect bSIz$e7Ps4zwPGKD00000NkvXXu0mjfi5=SH literal 0 HcmV?d00001 diff --git a/tests/data/automl/data/cifar10_subset/frog/0002.png b/tests/data/automl/data/cifar10_subset/frog/0002.png new file mode 100644 index 0000000000000000000000000000000000000000..8bec65a8e323c078b5c1af316048b5954735c11b GIT binary patch literal 2388 zcmV-a39I&rP)2V3nNmH>M#uIs!F0T^u4;N0RWg05fO{Bu3O6?Dr?9KM*ZotwgcgY zuik&Nxxehn-TZbxio;&tJMBk;K(u5OHyPG-*E(X^R>686r{k(D5m1yXkWIa>_7B9JbhYVJPV9Ly=|x5bx$K+Xu^)yIjVmg_aMw9)D|SWe zI;v``s;&eIWB>+W1VG1gh2ZV#P8kF1<-D3N1$Qn&OlEd?nCJbfYGyO{;4JQ^h2oMOs#Kg=1n6+@Z*bmda+8f$u#bnq!Iy#bCuL?uZ9FI@QrMt{#lfk>x z-PX<^IRrPjsPgEPq5_q(E*CWn%Q zLIuD?%*+T-JuI&3c9#CFn%~{HQl(+&N)SpaC0vBE&Nun~?9iS(i^iwZVLC3>tK0SO zwssc#v#H0aslU4^X3w4soxa_*QUWp%v9$&O0AahjS*~B7pP#7p{9Xz){p4UW30z&) zc3*eM#)++a677!sT@UZv>iOrVztl;-T|;>t>A}DM@70@m^YoX${`BnS`Y+$jFFOK3 zM8*&?Ab><;x>89{!~TqkJU{N(3%y_vD~(jqNSzi#?+#?BuX4 z^Z5))f=SdB8+m74n?DHK5U>DLLLr0gtlDnZ%#0`y0R%flAl>`Tylm=##)pGCXMQLa z9iob-N(xt`y>J#d;Sc#$-fnMiSI3h>-}Aeso*R<}((zrk0EnLF0U1ayi)|D12D!zn+l9zL_=DZIZ%6*3*AF05kIkd_N1gzXiAaLXtE!^fiIV&- zc=xtL(}vBa*8?}yTGh$|LTjxY z6*Ds+AR?lXTD2CH@3GPr#KrepZRPWqGlQhF(p98TQ`c;ZeD!gIl*CbbF!U$HSJnI6 zuHN34er?X2BywE~3P`G!IwC><6etl|w1u(pV4~BQx7$>Nm#^k>T zi|w*C{6Mu-7<<008!zq84o}Cd*qRJmX0v+OZ$Z?qJ2;sB&H3}YRsQk%I>F>Cy4JM2Y+D=S`e78M5o#>zx~kdJ`>wpW zzu$ql&$6_a9*@$m|MX|2_0LX^zIb}BE&7N~;FH8rm>k53=V;IO@sEFcrG#{S52av6 zEs?0}jO`c&APTeqh(ds{WXrB=L3B+EZA+4N#vs$vUOEZ;zUyegUDtuD6#CN63zrGKd9&wbod(d)PQ#qp*>{ zZ(hIo^2;yJJ~?~->GP&D^X2m4VWXHS&vO7kL?Psl|0<=D5(WG?H-$jNwlxSMjG{0M z5s}H7w%x3kq3=~iW{K{XtKMjg%9*bpKngbuKmXa!mGApR1ONzth|KiJ9sseW$Akkj zGiWJ*;l=gYOSs;jEE>FJ(D8cSnYb_A9o0AD!RfZ&LKqa!#Fz<~h5VF(8x%plvN z#hRt3N7GAHR%KPLFZ1mj_XiaS~TmMV>tV z@b>20FMjj#Ze2uS0wL_W4nugfn2%k5GM_&>KD)ZP{lyp0^GSY~&$N(X98*R`Kin0= z<>fk?B!|bDg6`$J>!NRO*SCn#`Eg1d%rZ{ey6&2;t@@CMQVMHrk|c|RIj3|Ig;^B8 zx_o_Gl`x2UW#|N*pPoqI+fg{+5kWNKLXN$pNg65aM=;!bcqB$0XQ4M{lq1=0?&Kf> z8bD%}2T91oe40}n6sxUNvM>??Z>-OG&KTbJ_4(nwbyZbe%|o6~c-si0u`?D^8qH!B zo7pT4TdT&=TBn^SRkNR@S&%Tsa2RobV7c8EtI`;7dUj%+5u!i8cYo}pG=ng*7Ja}u z1U`*If`E?jAOYokRBE)sO9R{8&{|}$)XD-+!Ze&8P8r2xudi-aW#0sZbc4!g(;$qu z>$TCg5j773S&b0kbzNbEVnUEsfa2NVY_+d61bP(4DMqoMyqX!o!M(fq> zhF}7`Q{&Kg9iapv*w$5+CRv*7cl)}jIp?Aqq#C6fX7k(uGru=C%+L@99B8eu=K+n! zVKi0}1}USQ#iXfsx9`>j@sq=&_a59Q*sqFWGRpx#lptqx-?kRo&boM+7-b*^QJM^L z^aQqcD4Gp%nBjmDR_zCb5e@+3m{ST2l!Jyp`}C8DCvDY2YtBzkFvf#Xo4Vc<!ECs6PM#Xl=7`wi?MhKrAoo)7I)%8kQU_oqyG)#(p!)TDD zF~A-IfIK6R#UYha660LR!hP4ZjRy`>=+i9v<3IZ0VzK<-JKuly>t~m*URK+E94DLG z&B4JzU-k92NrGg%--ZE)+MV4$MU-u}C2~O8fp|nm7=)<}JfV=x7t>)Fs-_t`wJGYe zlp7_wgLs}Wevl_a)u^_eviNOz z<$=i#!!(PK#dTEs^0GQ4xoM035g${-+OeDja`M~Qd>9$;5zkBDDn9pWH2o^A-3`Laj zbX#<hpbBke?O^eO0xO_XG&8E{u80Fvp`1t(Edw=@lk1tM_t9AYFzk2rO@;XbW z`6P*AR@b-}L)A1S2tse19VB!}bh2#)$r!*C<6s=z`Tg|agQID*NVCHazVr0<`ifD~ zbdsm@ILp&KgwTHWzc2swm%sYzs~2I&F$CivZ~IOvrHwXP6A~fSjz%ir5d^611%mkS zZ~?KCgLrfKCeO$unp|Ff)3xIL_a9P9UcJ7J@+n~W>gxLLZv8)>|GI2;K}0Ztm_Q|! 
z5Q1^a7>lA9d8#aSfZ$Opsnw<|A;y%EIA+7pIg72~FbJAvw^&Xey!VLbISRv*^Rv}v z^^gDb&;Rm^U*4@(ltLCjsd{Ubb51EO#E3APyFo~!?^bKh!0EkX zis&%-&pvzh*FXJf9A($5%}rUBqWk!)5 z$2>}Z`}}LEo%f)wnqCZ5U9Z*~ZB(3u(`il^yS*#gR-{?RLK-Fk3yBB7d5?gHl!q|j z0S|!#O6gvR>${t#>$St3v>*N8<0y*0e*Jpi*FX8QKl#CjPbXn~ygYt%@x&X90gmD* zOt|y#>dpH4wn*}LelmeXJ7^KXD9bX=IiVC|Y>bJbD9@)_%iq8F@^@c6*A{HcI*gJ* zn{TggCVBq7?|$@GpMH9{T()iBwW4jsu57lO{V-^)U0LoO@J-Woo%9|kJrM7lF-9pB z&>##$2q6Ff5eNYowfxO*pWocBmxpuA!WWlswe?{#@!nHL+qMaU064hcH&RLjL6$N~ z*lt%zC4r~oI09(k`wy3#^Ei$P!d=?|0HP>LlE@h;m4XmzV;Ku~+s$%5J3cy8N_z;W z2g|#n*c6+91(Qi$*R2|*ch*AfpwiB!*#QfnQUf7`D5VG?24e3$gwPn%G#&8X1MR)W zlosWV2VpBlv@HpQ+uiQwuCM?AU{6l(wQX;Vbk1p|dBV~>Z~KN2&jOqz2LvKS2mzRC z?SO-fAqSm@kRU9j^uPxMdEmo1;sLAMHcbKtolz13vN)a(gOE}p43!a-G3x*(L7JzY zs=D8Ufx(YHJWcZ{fTS*KWE4i$QKB7y&Rb)GAaKqNVvOR5bMBl4UdNo9Q4>fVMjk^# z$S@2+KqzH(GkOfd3?vKM4pR1(EDjPe_CjhxKpZjaOgCC+EM_Dy#i!8oh6Yj zw&nMpezz!BzFa{(NFlt(#ux}8K?p$ao#z21nB?h1Yc1s9F~o!tf+;1_>9p(mM~@yW kI|!*6i^|OoU<5tv&A{#wpP90rDb#4!=`?AUCbvX)VA zY?4hLo~`z(T;%)kAAbKYo6UAQE3%x?L^2_mpp+07oOYfhkepQob53!|XNwA@gc9ur zw?CzSP7RVD=x91+GSVa{kKK>Nfvw2e*30q_v4HMTCMlj&T*F{L0mCd%|hA%uYQvt?dN z7agLQ07RS^V2q-31~`*>jtPnXFQt?sXC#xHF-95Vf)`bm1(#9`F%W`NLIfwtAu&GK0GuOCtg(dPOmgQ< zuX@HIQQEdGc*LoYDZ0(J0U3*;Kq$r(Lx>dE>U@SmaH^veV+sjHgh+~sU_=Ru&gXfK z3EA%UW|V_b^{mRc^vT7D9KvP21PX#;iahCuem$9%)q4BV^@GAV0V6)yFm@6G1I)7= z7)~i!rw{_;aMNnUbD0V3l8cc6L=*!G+9_`|Lg>}i4IpdsNri+&Bx7RIYjIMc9Yah8 zC-dp^BX;`z?GrxI1byjx&zx3Lf;G01B}O4I zKbtJLAoZklE?Ju(flQS5x^I;~Xt2q67lU^pA?MR{wvIJ=I%-OZHP#wOy6(C2&Rc}y z_GQBXU7jye2-Ycv{Ka{F^Y)fuw0U}BTrx_*;Vg*^B>MKPiL@XLrQ~C9&L`(lN;rg= z2&IBEnlPtIsorReMDR!r9S5ZKsh*PC+w)lo^`s!gFBWT{@agmYu^GWz>n*x{x4!>+ zAS7|XUDJ7`av=o=+xIBOi^T#1QI7Wgs5K^pauA*%giuVrw^6+ecbnt8fBr4tU~D{` z^rz3?KK<<%tsE4@FmIcu9=wW(JLS!A@;c-qtEU{0fkYx9ivc;KL?*bD zO?M&y!+`(z=g(C>efQ>Sw>xe&f7v|UqloLGg3IgqpvP{skc-hJo@m9uU~)XP`^{su zTzSI2eA&FZS{E#Nmn+#92}94+DDK-CQrUr-+ha=mq93VuWytXlO^R#u#HM zif<45NluqPyk4$nB#~d;+`w|mGd?|a1md8_;}A?)U%y`0yI&r=F&m6VoD-nd2c^R1 zc@Hw1Ov+&xk_{ z$3YH2Ju<@E01Gko~)38PPp!PV8- z{?HH_ooNvVnMs7wWKy2`{@c@D>nMQ4m?$HJKtu>7Z=7d<&dwIz#1wr%QClauxLTI4 zE*5=v+8&=tFxFZ3a%evM^;17OUSzEqwADoh(e-1un-u)~eC;CcdZ)EjLx-`SOe>KU z%DV0T5L}#;C1(WTI4O&Qi{E{JTQKtNZg=;*gV{u4O0@-9QA5;+&! zkVP?CjS$=IIw6_Ng=#$V;q9-^e)Im-V-K(UL- z0|~*641_VpSwn#~P4m3#MOL=G?oI~BBn9w{b4lNQ_x9uGFNg?&5J6-d6@o+=W*BFH z|L)zTwf)me^ZD^`y_y5(!gwzv*btmGoY9->D;MHsdmK72sAzM$2p&#IPKm9aaef{{C4|OIBk)4`hR31gl3D}ql>P10DV2n(v z91;#dMv%dq(KFYNqt=!Q+#5Zq=EmX2$K98&n?+84^Zg|ic(dC=zPkFSH*c5A#n+$z zdg_&Os2iPko|8aPm`~YkR#?jxvqf___Wgj8XB1tmW>O4&vT`w!7OLvgxX*SI;l)sT-X!lF{X2vRYO}!DCF$M`dh? 
zXjw1nOdx;cppvtNtW-CgngdR{EQu(^NvROhd{V>z{MUc4B#pLv_`13M;m76qt9>(U zcZd4*4Py|z<*eNAPIv8|fgF^<6uce@+M^JQOpF$X=*MQu&9Tu2Jsy1->Vou3uctLuVXzZa89b$fFg6FwgH#tjjn=+Wi+>et^VxZgei zm(IiS)F70)!(nxHQI)m5f6}y)$j&wG=>Ie0ILWAdMQqci;00000NkvXXu0mjfn-`hl literal 0 HcmV?d00001 diff --git a/tests/data/automl/data/cifar10_subset/frog/0005.png b/tests/data/automl/data/cifar10_subset/frog/0005.png new file mode 100644 index 0000000000000000000000000000000000000000..a7443b4b46f8d9072983de6239025447d88e6eba GIT binary patch literal 2451 zcmV;E32gR>P)v(E)=n0yNH%+AXQoG%C5`3Sxe%nlfnz|(ZSqQalpqfgz}XQeMm&OTLGpOq z-P7CjLb6E~YdKYQmjBO1zK8heKm6n0GY(Sl$oD0Wik&eGd*i&*Oqi>3LC@B&letMoLP2p`3G`v(_4G0`O1< z5`}fnSzwHLejtTMkaO-)LWwf~z!+aUqLx#xL8N+Ep7r2|B?gb*PFB03~sEdUTg5YahDh$KiTLrOt7 z0*v@kFx;QLI6OScvNVbk&VssHKO%8xfk|hL0}IY^&Iu*pImNZ%7eB7ztU$McgP%#Y8elLOxqgg8Wvx{iJS#f#UJNr&j1uA5y^th>I{w(a!J z_L?|GC?Ys?7M+uvGA0FOju6h65YhqZIzyi>f1BqAf=S}Evpt}u#8GFM45$|%0MV`1 zO<8SG^2+J^r`uW;PPZITaLJi)paCqS9wFQSNZ;d}BO)OL0K_ai#NIm7QL8Ai6nbkl zI4k{#@QzZ#T}w?HShL=&tY8V3yB@3o#<+2Gy{qe{iIX5r10lHdgzx)=5MvA?a?S;o zr334np}Hkt9O$-Hlu~9vhQwGnW| zFaMlzWQ-9)7-N)B#<_5Pt-%@FyIvtU2e$8eNeMA(TRdMp7B>&i>+O%!_^ zef!7%`tCjTlDgSh+p%xohJKi3+41qw1cR&Fhe~N{to4+h;9*F-#SKFt@>}yZ@+seWH#OH3dhB+c4*Ds#P`Bb^}6l*qSyj} z@B2aE9Ubl8-QBl!{q=Mj z@r$4R_{Sd)he7DevZ@_&q~hx4gKGOYj*}z?04ezU%gbV0TFu^{KX^1r!qn$pQ&+03 zm@1y@?Q`e)Lc7PRKVO!=xw!wIx9=Z|dbwKu>o5Nab!$~^RV^7c#x7TD%6K};2E$?7 zb_59`wp|B=|MC3lm;dtb01yXJ5Tuk!-5X&xTSZu#r<-c|=J#)xWp{XbCI;!u;b0JL z)|>mQyT0r5ac(W?x;~7o@5MrTLV8NIO<57l^XcUJ>O)nk@<0D=HjBRgt1q(Q)B@S; zmZE98lF0jN@%{VT4~yq89X~!iY}RWb_~YYpGR}3^4uatHnM!+z5@hA zjG`zELjdsoz{RxKHNX1r-yUb#*?d1A#9bSS8p&-{|NicARjD8ufrIU40l=ENdwP1> z6}vc=XRppIfs@lY+}_u9O%TH;(djWQS}sW#U0yfBQdb%=C2olR%6U0Hd)-*z1U1l0%#$m$#uqyAjb>Eupwz#{!6GBY)_To4+#$?&(<=HDH zqQz=klug+*RPaw;e-;lDLNE*jW0;N-W3-fh;QLKe5@ymgsvGsYcbAv9x64{<0GY-BpGndjnzWZ&p-dPZmO%R+bkdFOOnhmYBVHhG}qpI!q@MRnaaU2UF>b7wX)|)jKWI4;PF3yD) z1jBgXhk=wb&$6edr!Wkpq)8I1wj5={pa1-4C-bAhFp2{C^5nQIt8q5s1Y^&0&N0R% ze}Y2O~U*6X{7W@R1Yg3vimDC!|c0LG4G{V+g+hsOtk(WOA6n_&0|wx%HpWm&tUWjX7L7Hgv#ukA#BmaZLm@=hb=KM_E%c>&|A+VU7cideNssk?)3vScj5XLaoodvs z*m*u148zHEIv5Pmb!AoceZ^J9t8Ee_i7$To)nAOJhm>+LnH>__$8m5v8(m(~>usgW zZnxXs+}y6#%d@l7LFl6aiXa#bg6Q`4L)*5;>HZ)ZpuxuvkDyiGY3S^qjp7$CUyaAJ z;b1hLpN#en0H`3062wsYlQexX%g=AWuXr+^=8MJs_4&KI%ZoJ0$KwIxjyQl0n_XEq zO(CS-l*SrxfCSoVHy;gu@WYP|=JVm+EX^ljk}@W!U;-&Y>p0=jaC$aBd3k;J-J*6V zKKbOM**LpAzqq-)xVc0>@FL$6K+(C_4?EQvss-m->!MLiH%G5dM{$&RVdw{5;7dZN zF_beQglvsLV^GF(NZL~l8^hckU>&L^{H1xc>ZXfQKwW|lQpQdRPh7^ed zsFkYfX4m%fVYD|M2`L0)1i)(3bxu&u1ruNpDAT$^=F?<2Jvjc+$>Omy*B=Vzns7L% zw%e0~{BV}NJHIKamISTjk{YM_u2Nm^kPtcxlSG6R2wFpQ4~;@Ea10UH{{bvJBdjvK R> Date: Fri, 1 Mar 2024 11:05:02 -0800 Subject: [PATCH 06/42] feature: Add TensorFlow 2.14 image configs (#4446) --- src/sagemaker/fw_utils.py | 2 +- .../image_uri_config/tensorflow.json | 91 ++++++++++++++++++- 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 34abfdc76a..1fb8c8eaaa 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -90,7 +90,7 @@ "local_gpu", ) SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = { - # tf 2.12 should not be supported: smdataparallel excludes support for tf 2.12. + # tf 2.12 should not be supported: smdataparallel excludes support for tf>=2.12. 
"tensorflow": [ "2.3", "2.3.1", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 3453f5f120..1aaef31c78 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -330,7 +330,8 @@ "2.10": "2.10.1", "2.11": "2.11.1", "2.12": "2.12.1", - "2.13": "2.13.0" + "2.13": "2.13.0", + "2.14": "2.14.1" }, "versions": { "1.10.0": { @@ -2135,6 +2136,46 @@ "ca-west-1": "204538143572" }, "repository": "tensorflow-inference" + }, + "2.14.1": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" } } }, @@ -2316,7 +2357,8 @@ "2.10": "2.10.1", "2.11": "2.11.0", "2.12": "2.12.0", - "2.13": "2.13.0" + "2.13": "2.13.0", + "2.14": "2.14.1" }, "versions": { "1.10.0": { @@ -4190,7 +4232,50 @@ "ca-west-1": "204538143572" }, "repository": "tensorflow-training" + }, + "2.14.1": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-training" } } } -} \ No newline at end of file +} From 00327fc1d173b5c6a3efd0704d224886a6936b1c Mon Sep 17 00:00:00 2001 From: Rohan Gujarathi Date: Fri, 1 Mar 2024 13:05:18 -0800 Subject: [PATCH 07/42] fix: remove enable_network_isolation from 
the python doc (#4465) Co-authored-by: Rohan Gujarathi --- src/sagemaker/remote_function/client.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 49091fc60c..0dc69d8647 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -694,11 +694,6 @@ def __init__( encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between training containers is encrypted for the training job. Defaults to ``False``. - enable_network_isolation (bool): A flag that specifies whether container will run in - network isolation mode. Defaults to ``False``. Network isolation mode restricts the - container access to outside networks (such as the Internet). The container does not - make any inbound or outbound network calls. Also known as Internet-free mode. - spark_config (SparkConfig): Configurations to the Spark application that runs on Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri will be used for training. Note that ``image_uri`` can not be specified at the From 4e5155cf60aa1a2c1fd89a54a2672b671fedf86f Mon Sep 17 00:00:00 2001 From: cansun <80425164+can-sun@users.noreply.github.com> Date: Fri, 1 Mar 2024 13:34:23 -0800 Subject: [PATCH 08/42] doc: Add doc for new feature processor APIs and classes (#4250) --- doc/api/prep_data/feature_store.rst | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/doc/api/prep_data/feature_store.rst b/doc/api/prep_data/feature_store.rst index 50a10c5089..731b2e32d1 100644 --- a/doc/api/prep_data/feature_store.rst +++ b/doc/api/prep_data/feature_store.rst @@ -60,6 +60,7 @@ Feature Definition :members: :show-inheritance: + Inputs ****** @@ -181,9 +182,13 @@ Feature Processor Data Source :members: :show-inheritance: +.. autoclass:: sagemaker.feature_store.feature_processor.PySparkDataSource + :members: + :show-inheritance: -Feature Processor Scheduler -*************************** + +Feature Processor Scheduler and Triggers +**************************************** .. automethod:: sagemaker.feature_store.feature_processor.to_pipeline @@ -196,3 +201,12 @@ Feature Processor Scheduler .. automethod:: sagemaker.feature_store.feature_processor.describe .. automethod:: sagemaker.feature_store.feature_processor.list_pipelines + +.. automethod:: sagemaker.feature_store.feature_processor.put_trigger + +.. automethod:: sagemaker.feature_store.feature_processor.enable_trigger + +.. automethod:: sagemaker.feature_store.feature_processor.disable_trigger + +.. 
automethod:: sagemaker.feature_store.feature_processor.delete_trigger + From ac4e861c222232db3d65d7d2bb6113df4cab76ad Mon Sep 17 00:00:00 2001 From: Justin Date: Mon, 4 Mar 2024 10:11:54 -0600 Subject: [PATCH 09/42] fix: properly close sagemaker config file after loading config (#4457) Closes #4456 --- src/sagemaker/config/config.py | 4 +++- tests/unit/sagemaker/config/test_config.py | 23 +++++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/config/config.py b/src/sagemaker/config/config.py index fa30b05a0e..23b3957905 100644 --- a/src/sagemaker/config/config.py +++ b/src/sagemaker/config/config.py @@ -181,7 +181,9 @@ def _load_config_from_file(file_path: str) -> dict: f"Provide a valid file path" ) logger.debug("Fetching defaults config from location: %s", file_path) - return yaml.safe_load(open(inferred_file_path, "r")) + with open(inferred_file_path, "r") as f: + content = yaml.safe_load(f) + return content def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict: diff --git a/tests/unit/sagemaker/config/test_config.py b/tests/unit/sagemaker/config/test_config.py index 9e89a5890b..35135db81e 100644 --- a/tests/unit/sagemaker/config/test_config.py +++ b/tests/unit/sagemaker/config/test_config.py @@ -34,7 +34,9 @@ @pytest.fixture() def config_file_as_yaml(get_data_dir): config_file_path = os.path.join(get_data_dir, "config.yaml") - return open(config_file_path, "r").read() + with open(config_file_path, "r") as f: + content = f.read() + return content @pytest.fixture() @@ -42,7 +44,13 @@ def expected_merged_config(get_data_dir): expected_merged_config_file_path = os.path.join( get_data_dir, "expected_output_config_after_merge.yaml" ) - return yaml.safe_load(open(expected_merged_config_file_path, "r").read()) + with open(expected_merged_config_file_path, "r") as f: + content = yaml.safe_load(f.read()) + return content + + +def _raise_valueerror(*args): + raise ValueError(args) def test_config_when_default_config_file_and_user_config_file_is_not_found(): @@ -60,7 +68,8 @@ def test_config_when_overriden_default_config_file_is_not_found(get_data_dir): def test_invalid_config_file_which_has_python_code(get_data_dir): invalid_config_file_path = os.path.join(get_data_dir, "config_file_with_code.yaml") # no exceptions will be thrown with yaml.unsafe_load - yaml.unsafe_load(open(invalid_config_file_path, "r")) + with open(invalid_config_file_path, "r") as f: + yaml.unsafe_load(f) # PyYAML will throw exceptions for yaml.safe_load. 
SageMaker Config is using # yaml.safe_load internally with pytest.raises(ConstructorError) as exception_info: @@ -228,7 +237,8 @@ def test_merge_of_s3_default_config_file_and_regular_config_file( get_data_dir, expected_merged_config, s3_resource_mock ): config_file_content_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml") - config_file_as_yaml = open(config_file_content_path, "r").read() + with open(config_file_content_path, "r") as f: + config_file_as_yaml = f.read() config_file_bucket = "config-file-bucket" config_file_s3_prefix = "config/config.yaml" config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix) @@ -440,8 +450,11 @@ def test_load_local_mode_config(mock_load_config): mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH) -def test_load_local_mode_config_when_config_file_is_not_found(): +@patch("sagemaker.config.config._load_config_from_file", side_effect=_raise_valueerror) +def test_load_local_mode_config_when_config_file_is_not_found(mock_load_config): + # Patch is needed because one might actually have a local config file assert load_local_mode_config() is None + mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH) @pytest.mark.parametrize( From b857eadb792200cb036c34cdb62ae58d72f1cdc7 Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Mon, 4 Mar 2024 12:33:33 -0500 Subject: [PATCH 10/42] feat: instance specific jumpstart host requirements (#4397) * feat: instance specific jumpstart host requirements * chore: add js support for copies resource requirement, enforce coupling with ResourceRequirements class * fix: typing * fix: pylint --- .../artifacts/resource_requirements.py | 66 ++++++++++++++---- src/sagemaker/jumpstart/factory/model.py | 1 + src/sagemaker/jumpstart/types.py | 23 +++++++ src/sagemaker/resource_requirements.py | 7 +- tests/unit/sagemaker/jumpstart/constants.py | 20 ++++++ tests/unit/sagemaker/jumpstart/test_types.py | 23 +++++++ .../jumpstart/test_resource_requirements.py | 68 +++++++++++++++++++ 7 files changed, 193 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index ecf6d1b5ea..8baaaafd2a 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -13,7 +13,7 @@ """This module contains functions for obtaining JumpStart resoure requirements.""" from __future__ import absolute_import -from typing import Optional +from typing import Dict, Optional, Tuple from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -28,6 +28,20 @@ from sagemaker.session import Session from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[ + str, Dict[str, Tuple[str, str]] +] = { + "requests": { + "num_accelerators": ("num_accelerators", "num_accelerators"), + "num_cpus": ("num_cpus", "num_cpus"), + "copies": ("copies", "copy_count"), + "min_memory_mb": ("memory", "min_memory"), + }, + "limits": { + "max_memory_mb": ("memory", "max_memory"), + }, +} + def _retrieve_default_resources( model_id: str, @@ -38,6 +52,7 @@ def _retrieve_default_resources( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + instance_type: Optional[str] = 
None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -63,6 +78,8 @@ def _retrieve_default_resources( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + instance_type (str): An instance type to optionally supply in order to get + host requirements specific for the instance type. Returns: str: The default resource requirements to use for the model or None. @@ -91,23 +108,44 @@ def _retrieve_default_resources( is_dynamic_container_deployment_supported = ( model_specs.dynamic_container_deployment_supported ) - default_resource_requirements = model_specs.hosting_resource_requirements + default_resource_requirements: Dict[str, int] = ( + model_specs.hosting_resource_requirements or {} + ) else: raise NotImplementedError( f"Unsupported script scope for retrieving default resource requirements: '{scope}'" ) + instance_specific_resource_requirements: Dict[str, int] = ( + model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements( + instance_type + ) + if instance_type + and getattr(model_specs, "hosting_instance_type_variants", None) is not None + else {} + ) + + default_resource_requirements = { + **default_resource_requirements, + **instance_specific_resource_requirements, + } + if is_dynamic_container_deployment_supported: - requests = {} - if "num_accelerators" in default_resource_requirements: - requests["num_accelerators"] = default_resource_requirements["num_accelerators"] - if "min_memory_mb" in default_resource_requirements: - requests["memory"] = default_resource_requirements["min_memory_mb"] - if "num_cpus" in default_resource_requirements: - requests["num_cpus"] = default_resource_requirements["num_cpus"] - - limits = {} - if "max_memory_mb" in default_resource_requirements: - limits["memory"] = default_resource_requirements["max_memory_mb"] - return ResourceRequirements(requests=requests, limits=limits) + + all_resource_requirement_kwargs = {} + + for ( + requirement_type, + spec_field_to_resource_requirement_map, + ) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items(): + requirement_kwargs = {} + for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items(): + if spec_field in default_resource_requirements: + requirement_kwargs[resource_requirement[0]] = default_resource_requirements[ + spec_field + ] + + all_resource_requirement_kwargs[requirement_type] = requirement_kwargs + + return ResourceRequirements(**all_resource_requirement_kwargs) return None diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 0e1dbfe07d..9448e45cc2 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -503,6 +503,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + instance_type=kwargs.instance_type, ) return kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cd10a7123b..8c74e007ae 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -512,6 +512,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str instance_type=instance_type, 
property_name="artifact_key" ) + def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]: + """Returns instance specific resource requirements. + + If a value exists for both the instance family and instance type, the instance type value + is chosen. + """ + + instance_specific_resource_requirements: dict = ( + self.variants.get(instance_type, {}) + .get("properties", {}) + .get("resource_requirements", {}) + ) + + instance_type_family = get_instance_type_family(instance_type) + + instance_family_resource_requirements: dict = ( + self.variants.get(instance_type_family, {}) + .get("properties", {}) + .get("resource_requirements", {}) + ) + + return {**instance_family_resource_requirements, **instance_specific_resource_requirements} + def _get_instance_specific_property( self, instance_type: str, property_name: str ) -> Optional[str]: diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index f0be00ea09..62389ba127 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -16,6 +16,7 @@ import logging from typing import Optional +from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts @@ -34,7 +35,8 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: + instance_type: Optional[str] = None, +) -> ResourceRequirements: """Retrieves the default resource requirements for the model matching the given arguments. Args: @@ -59,6 +61,8 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + instance_type (str): An instance type to optionally supply in order to get + host requirements specific for the instance type. Returns: str: The default resource requirements to use for the model. 
@@ -83,4 +87,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + instance_type=instance_type, ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index a60f8f9315..43777ab14a 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -840,8 +840,22 @@ "model_package_arn": "$gpu_model_package_arn", } }, + "g5": { + "properties": { + "resource_requirements": { + "num_accelerators": 888810, + "randon-field-2": 2222, + } + } + }, "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}, + "resource_requirements": {"num_accelerators": 10}, + } + }, "ml.g5.48xlarge": { "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} }, @@ -857,6 +871,12 @@ "framework_version": "1.5.0", "py_version": "py3", }, + "dynamic_container_deployment_supported": True, + "hosting_resource_requirements": { + "min_memory_mb": 81999, + "num_accelerators": 1, + "random_field_1": 1, + }, "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 7f842a053c..81a6c7dd14 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -34,6 +34,7 @@ "variants": { "ml.p2.12xlarge": { "properties": { + "resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9}, "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, "supported_inference_instance_types": ["ml.p5.xlarge"], "default_inference_instance_type": "ml.p5.xlarge", @@ -60,6 +61,11 @@ "p2": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { + "resource_requirements": { + "req2": {"2": 5, "9": 999}, + "req3": 999, + "req4": "blah", + }, "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], "default_inference_instance_type": "ml.p2.xlarge", "metrics": [ @@ -880,3 +886,20 @@ def test_jumpstart_training_artifact_key_instance_variants(): ) is None ) + + +def test_jumpstart_resource_requirements_instance_variants(): + assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements( + instance_type="ml.p2.xlarge" + ) == {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"} + + assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements( + instance_type="ml.p2.12xlarge" + ) == {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"} + + assert ( + INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements( + instance_type="ml.p99.12xlarge" + ) + == {} + ) diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 86031fbd57..1ad25f962f 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -18,6 +18,10 @@ import pytest from sagemaker import resource_requirements +from 
sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.jumpstart.artifacts.resource_requirements import ( + REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP, +) from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -47,6 +51,55 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): patched_get_model_specs.reset_mock() +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_resource_requirements_instance_type_variants(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_special_model_spec + region = "us-west-2" + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + + model_id, model_version = "variant-model", "*" + default_inference_resource_requirements = resource_requirements.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + instance_type="ml.g5.xlarge", + ) + assert default_inference_resource_requirements.requests == { + "memory": 81999, + "num_accelerators": 10, + } + + default_inference_resource_requirements = resource_requirements.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + instance_type="ml.g5.555xlarge", + ) + assert default_inference_resource_requirements.requests == { + "memory": 81999, + "num_accelerators": 888810, + } + + default_inference_resource_requirements = resource_requirements.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + instance_type="ml.f9.555xlarge", + ) + assert default_inference_resource_requirements.requests == { + "memory": 81999, + "num_accelerators": 1, + } + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec @@ -74,3 +127,18 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): resource_requirements.retrieve_default( region=region, model_id=model_id, model_version=model_version, scope="training" ) + + +def test_jumpstart_supports_all_resource_requirement_fields(): + + all_tracked_resource_requirement_fields = { + field + for requirements in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.values() + for _, field in requirements.values() + } + + excluded_resource_requirement_fields = {"requests", "limits"} + assert ( + set(ResourceRequirements().__dict__.keys()) - excluded_resource_requirement_fields + == all_tracked_resource_requirement_fields + ) From 72fd0fa3fe109c2ca3604b7bf7845a2527a883bb Mon Sep 17 00:00:00 2001 From: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Date: Mon, 4 Mar 2024 09:54:35 -0800 Subject: [PATCH 11/42] change: Bump Apache Airflow version to 2.8.2 (#4470) * Update tox.ini * Update test_requirements.txt --- requirements/extras/test_requirements.txt | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index ba7d8c3849..69b4d0c2de 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -12,7 +12,7 @@ awslogs==0.14.0 black==22.3.0 
stopit==1.1.2 # Update tox.ini to have correct version of airflow constraints file -apache-airflow==2.8.1 +apache-airflow==2.8.2 apache-airflow-providers-amazon==7.2.1 attrs>=23.1.0,<24 fabric==2.6.0 diff --git a/tox.ini b/tox.ini index 66e546372b..d2f5e67cfb 100644 --- a/tox.ini +++ b/tox.ini @@ -81,7 +81,7 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" - pip install 'apache-airflow==2.8.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.8.1/constraints-3.8.txt" + pip install 'apache-airflow==2.8.2' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.8.2/constraints-3.8.txt" pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' From 892ba38c28b103e82e791558a31fea6bfe1cc423 Mon Sep 17 00:00:00 2001 From: gv Date: Mon, 4 Mar 2024 23:19:41 +0000 Subject: [PATCH 12/42] fix: make sure gpus are found in local_gpu run (#4384) * fix: make sure gpus are found in local_gpu run * fix: black formatting * fix: adjust unit test --- src/sagemaker/local/image.py | 4 +++- tests/unit/sagemaker/local/test_local_image.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 7893ee9260..39c879ef6d 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -860,7 +860,9 @@ def _create_docker_host( # to setting --runtime=nvidia in the docker commandline. if self.instance_type == "local_gpu": host_config["deploy"] = { - "resources": {"reservations": {"devices": [{"capabilities": ["gpu"]}]}} + "resources": { + "reservations": {"devices": [{"count": "all", "capabilities": ["gpu"]}]} + } } if not self.is_studio and command == "serve": diff --git a/tests/unit/sagemaker/local/test_local_image.py b/tests/unit/sagemaker/local/test_local_image.py index ebca91a9f9..08c55fa0b4 100644 --- a/tests/unit/sagemaker/local/test_local_image.py +++ b/tests/unit/sagemaker/local/test_local_image.py @@ -871,7 +871,7 @@ def test_container_has_gpu_support(tmpdir, sagemaker_session): docker_host = sagemaker_container._create_docker_host("host-1", {}, set(), "train", []) assert "deploy" in docker_host assert docker_host["deploy"] == { - "resources": {"reservations": {"devices": [{"capabilities": ["gpu"]}]}} + "resources": {"reservations": {"devices": [{"count": "all", "capabilities": ["gpu"]}]}} } From 9a26978207d0181252cb30e072845c7f3aa3d5a4 Mon Sep 17 00:00:00 2001 From: akrishna1995 <38850354+akrishna1995@users.noreply.github.com> Date: Mon, 4 Mar 2024 15:20:10 -0800 Subject: [PATCH 13/42] feat: pin dll version to support python3.11 to the sdk (#4472) Co-authored-by: Ashwin Krishna --- tox.ini | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tox.ini b/tox.ini index d2f5e67cfb..d990467b3b 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. 
[tox] -envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py38,py39,py310 +envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py38,py39,py310,py311 skip_missing_interpreters = False @@ -84,12 +84,13 @@ commands = pip install 'apache-airflow==2.8.2' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.8.2/constraints-3.8.txt" pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' + pip install 'dill>=0.3.8' pytest --cov=sagemaker --cov-append {posargs} - {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 +{env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 deps = .[test] depends = - {py38,py39,py310}: clean + {py38,py39,py310,p311}: clean [testenv:flake8] skipdist = true From 69a9fcd0cee94da94b6ae121b4859b7b7af1ebb6 Mon Sep 17 00:00:00 2001 From: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Date: Tue, 5 Mar 2024 07:05:42 -0800 Subject: [PATCH 14/42] fix: Skip No Canvas regions for test_deploy_best_candidate (#4477) --- tests/integ/__init__.py | 7 +++++++ tests/integ/test_auto_ml_v2.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index fbf32c3acf..434f4dd744 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -139,6 +139,13 @@ "af-south-1", "eu-south-1", ] +NO_CANVAS_REGIONS = [ + "ca-central-1", + "eu-north-1", + "eu-west-2", + "sa-east-1", + "us-west-1", +] NO_MODEL_MONITORING_REGIONS = ["me-south-1", "af-south-1", "eu-south-1"] DRIFT_CHECK_BASELINES_SUPPORTED_REGIONS = [ "us-east-2", diff --git a/tests/integ/test_auto_ml_v2.py b/tests/integ/test_auto_ml_v2.py index 7749b0c5f2..91802770d1 100644 --- a/tests/integ/test_auto_ml_v2.py +++ b/tests/integ/test_auto_ml_v2.py @@ -330,7 +330,8 @@ def test_best_candidate( @pytest.mark.skipif( - tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS, + tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS + or tests.integ.test_region() in tests.integ.NO_CANVAS_REGIONS, reason="AutoML is not supported in the region yet.", ) @pytest.mark.release From 0a48a8af364bb261ac47d355fbf80317c7bdbbd1 Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 5 Mar 2024 18:25:26 +0000 Subject: [PATCH 15/42] prepare release v2.211.0 --- CHANGELOG.md | 23 +++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bcc19b6b22..4f42089c8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## v2.211.0 (2024-03-05) + +### Features + + * pin dll version to support python3.11 to the sdk + * instance specific jumpstart host requirements + * Add TensorFlow 2.14 image configs + * Add AutoMLV2 support + * Support selective pipeline execution between function step and regular step + * Add new Triton DLC URIs + +### Bug Fixes and Other Changes + + * Skip No Canvas regions for test_deploy_best_candidate + * make sure gpus are found in local_gpu run + * Bump Apache Airflow version to 2.8.2 + * properly close sagemaker config file after loading config + * remove enable_network_isolation from the python doc + +### Documentation Changes + + * Add doc for new feature processor APIs and classes + ## v2.210.0 (2024-02-28) ### Features diff --git a/VERSION b/VERSION index fdf99214f7..1a30f1d422 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.210.1.dev0 +2.211.0 From 
8036ad360cbcee6a78fcc4837e8e673bfd668844 Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 5 Mar 2024 18:25:28 +0000 Subject: [PATCH 16/42] update development version to v2.211.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 1a30f1d422..e2c63e5edf 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.211.0 +2.211.1.dev0 From 55940adb44cea8966f0f83109883cf09ceb94078 Mon Sep 17 00:00:00 2001 From: Samrudhi Sharma <154457034+samruds@users.noreply.github.com> Date: Tue, 5 Mar 2024 17:52:39 -0800 Subject: [PATCH 17/42] change: Enhance model builder selection logic to include model size (#4429) * change: Enhance model builder selection logic to include model size * Fix conflicts * Address PR comments * fix formatting * fix formatting of test * Fix token in tasks.json * Increase coverage for tests * fix formatting * Fix requirements * Import code instead of importing accelerate * Fix formatting * Setup dependencies --- doc/requirements.txt | 1 + .../extras/huggingface_requirements.txt | 1 + requirements/extras/test_requirements.txt | 1 + setup.py | 1 + src/sagemaker/serve/builder/model_builder.py | 78 ++- src/sagemaker/serve/schema/task.json | 2 +- .../serve/test_serve_model_builder_gpu.py | 184 +++++++ .../serve/test_serve_transformers.py | 20 +- .../serve/builder/test_model_builder.py | 516 ++++++++++++++++++ tests/unit/sagemaker/serve/utils/test_task.py | 2 +- 10 files changed, 796 insertions(+), 10 deletions(-) create mode 100644 requirements/extras/huggingface_requirements.txt create mode 100644 tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py diff --git a/doc/requirements.txt b/doc/requirements.txt index 3d5618ce32..a65e0e4050 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -4,3 +4,4 @@ docutils==0.15.2 packaging==20.9 jinja2==3.1.3 schema==0.7.5 +accelerate>=0.24.1,<=0.27.0 diff --git a/requirements/extras/huggingface_requirements.txt b/requirements/extras/huggingface_requirements.txt new file mode 100644 index 0000000000..31c6e65899 --- /dev/null +++ b/requirements/extras/huggingface_requirements.txt @@ -0,0 +1 @@ +accelerate>=0.24.1,<=0.27.0 diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 69b4d0c2de..8bd665d1df 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -39,3 +39,4 @@ tritonclient[http]<2.37.0 onnx==1.14.1 # tf2onnx==1.15.1 nbformat>=5.9,<6 +accelerate>=0.24.1,<=0.27.0 diff --git a/setup.py b/setup.py index b1070319d3..5b8845efed 100644 --- a/setup.py +++ b/setup.py @@ -79,6 +79,7 @@ def read_requirements(filename): "feature-processor": read_requirements( "requirements/extras/feature-processor_requirements.txt" ), + "huggingface": read_requirements("requirements/extras/huggingface_requirements.txt"), } # Meta dependency groups extras["all"] = [item for group in extras.values() for item in group] diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 8ca6a5d4ab..c66057397f 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -20,9 +20,11 @@ from pathlib import Path +from accelerate.commands.estimate import estimate_command_parser, gather_data from sagemaker import Session from sagemaker.model import Model from sagemaker.base_predictor import PredictorBase +from sagemaker.djl_inference import defaults from sagemaker.serializers import NumpySerializer, TorchTensorSerializer from 
sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer from sagemaker.serve.builder.schema_builder import SchemaBuilder @@ -41,6 +43,7 @@ from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.predictors import _get_local_mode_predictor +from sagemaker.serve.utils.hardware_detector import _get_gpu_info, _get_gpu_info_fallback from sagemaker.serve.detector.image_detector import ( auto_detect_container, _detect_framework_and_version, @@ -67,6 +70,9 @@ ModelServer.DJL_SERVING, } +MIB_CONVERSION_FACTOR = 0.00000095367431640625 +MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer + # pylint: disable=attribute-defined-outside-init @dataclass @@ -569,7 +575,7 @@ def wrapper(*args, **kwargs): # It supports two modes of deployment # 1/ SageMaker Endpoint # 2/ Local launch with container - def build( + def build( # pylint: disable=R0911 self, mode: Type[Mode] = None, role_arn: str = None, @@ -625,6 +631,13 @@ def build( if model_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() + elif self._can_fit_on_single_gpu(): + return self._build_for_transformers() + elif ( + self.model in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES + or self.model in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES + ): + return self._build_for_djl() else: return self._build_for_transformers() @@ -696,3 +709,66 @@ def _schema_builder_init(self, model_task: str): self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) except ValueError: raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.") + + def _total_inference_model_size_mib(self): + """Calculates the model size from HF accelerate + + This function gets the model size from accelerate. It also adds a + padding and converts to size MiB. When performing inference, expect + to add up to an additional 20% to the given model size as found by EleutherAI. 
+ """ + dtypes = self.env_vars.get("dtypes", "float32") + parser = estimate_command_parser() + args = parser.parse_args([self.model, "--dtypes", dtypes]) + + output = gather_data( + args + ) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam" + + if output is None: + raise ValueError(f"Could not get Model size for {self.model}") + + total_memory_size_mib = MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR + logger.info("Total memory size MIB: %s", total_memory_size_mib) + return total_memory_size_mib + + def _can_fit_on_single_gpu(self) -> Type[bool]: + """Check if model can fit on a single GPU + + If the size of the model is <= single gpu memory size, returns True else False + """ + try: + single_gpu_size_mib = self._try_fetch_gpu_info() + if self._total_inference_model_size_mib() <= single_gpu_size_mib: + logger.info( + "Total inference model size MIB %s, single GPU size for instance MIB %s", + self._total_inference_model_size_mib(), + single_gpu_size_mib, + ) + return True + return False + except ValueError: + logger.info("Unable to determine single GPU size for instance %s", self.instance_type) + return False + + def _try_fetch_gpu_info(self): + """Get GPU info + + This function gets the GPU info or fallback to set the size of a single GPU + """ + try: + gpu_info = _get_gpu_info(self.instance_type, self.sagemaker_session) + logger.info("GPU info %s for instance %s", gpu_info, self.instance_type) + return gpu_info[1] / gpu_info[0] + except ValueError: + pass + try: + gpu_fallback = _get_gpu_info_fallback( + self.instance_type, self.sagemaker_session.boto_region_name + ) + logger.info("GPU fallback picked up %s", gpu_fallback) + return gpu_fallback[1] / gpu_fallback[0] + except ValueError: + raise ValueError( + f"Unable to determine single GPU size for instance: [{self.instance_type}]" + ) diff --git a/src/sagemaker/serve/schema/task.json b/src/sagemaker/serve/schema/task.json index 9ee6d186a2..1a7bdce5d0 100644 --- a/src/sagemaker/serve/schema/task.json +++ b/src/sagemaker/serve/schema/task.json @@ -2,7 +2,7 @@ "fill-mask": { "sample_inputs": { "properties": { - "inputs": "Paris is the of France.", + "inputs": "Paris is the [MASK] of France.", "parameters": {} } }, diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py new file mode 100644 index 0000000000..933c18bacf --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py @@ -0,0 +1,184 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.serve.builder.model_builder import ModelBuilder, Mode +import tests.integ +from tests.integ.sagemaker.serve.constants import ( + HF_DIR, + PYTHON_VERSION_IS_NOT_310, + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, +) +from tests.integ.timeout import timeout +from tests.integ.utils import cleanup_model_resources, gpu_list, retry_with_instance_list +import logging + +logger = logging.getLogger(__name__) + +model_id = "bert-base-uncased" + +sample_input = {"inputs": "Hello I'm a [MASK] model."} + +sample_output = [ + { + "score": 0.10731109976768494, + "token": 4827, + "token_str": "fashion", + "sequence": "hello i'm a fashion model.", + }, + { + "score": 0.08774465322494507, + "token": 2535, + "token_str": "role", + "sequence": "hello i'm a role model.", + }, + { + "score": 0.05338414013385773, + "token": 2047, + "token_str": "new", + "sequence": "hello i'm a new model.", + }, + { + "score": 0.04667224362492561, + "token": 3565, + "token_str": "super", + "sequence": "hello i'm a super model.", + }, + { + "score": 0.027096163481473923, + "token": 2986, + "token_str": "fine", + "sequence": "hello i'm a fine model.", + }, +] + + +@pytest.fixture +def model_input(): + return {"inputs": "The man worked as a [MASK]."} + + +@pytest.fixture +def model_builder_model_schema_builder(): + return ModelBuilder( + model_path=HF_DIR, model=model_id, schema_builder=SchemaBuilder(sample_input, sample_output) + ) + + +@pytest.fixture +def model_builder(request): + return request.getfixturevalue(request.param) + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS + and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS, + reason="no ml.p2 or ml.p3 instances in this region", +) +@retry_with_instance_list(gpu_list(tests.integ.test_region())) +@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True) +def test_non_text_generation_model_single_GPU( + sagemaker_session, model_builder, model_input, **kwargs +): + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session) + caught_ex = None + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Running in SAGEMAKER_ENDPOINT mode") + predictor = model.deploy( + mode=Mode.SAGEMAKER_ENDPOINT, + instance_type=kwargs["instance_type"], + initial_instance_count=1, + ) + logger.info("Endpoint successfully deployed.") + prediction = predictor.predict(model_input) + assert prediction is not None + + endpoint_name = predictor.endpoint_name + sagemaker_client = sagemaker_session.boto_session.client("sagemaker") + endpoint_config_name = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)[ + "EndpointConfigName" + ] + actual_instance_type = sagemaker_client.describe_endpoint_config( + EndpointConfigName=endpoint_config_name + )["ProductionVariants"][0]["InstanceType"] + assert kwargs["instance_type"] == actual_instance_type + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"Exception {caught_ex} was thrown when running model builder 
single GPU test" + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS + and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS, + reason="no ml.p2 or ml.p3 instances in this region", +) +@retry_with_instance_list(gpu_list(tests.integ.test_region())) +@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True) +def test_non_text_generation_model_multi_GPU( + sagemaker_session, model_builder, model_input, **kwargs +): + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + caught_ex = None + model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session) + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Running in SAGEMAKER_ENDPOINT mode") + predictor = model.deploy( + mode=Mode.SAGEMAKER_ENDPOINT, + instance_type=kwargs["instance_type"], + initial_instance_count=1, + ) + logger.info("Endpoint successfully deployed.") + prediction = predictor.predict(model_input) + assert prediction is not None + + endpoint_name = predictor.endpoint_name + sagemaker_client = sagemaker_session.boto_session.client("sagemaker") + endpoint_config_name = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)[ + "EndpointConfigName" + ] + actual_instance_type = sagemaker_client.describe_endpoint_config( + EndpointConfigName=endpoint_config_name + )["ProductionVariants"][0]["InstanceType"] + assert kwargs["instance_type"] == actual_instance_type + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"Exception {caught_ex} was thrown when running model builder multi GPU test" diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 735f60d0f2..64029f7290 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -15,7 +15,7 @@ import pytest from sagemaker.serve.builder.schema_builder import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder, Mode - +import tests.integ from tests.integ.sagemaker.serve.constants import ( HF_DIR, PYTHON_VERSION_IS_NOT_310, @@ -23,7 +23,7 @@ ) from tests.integ.timeout import timeout -from tests.integ.utils import cleanup_model_resources +from tests.integ.utils import cleanup_model_resources, gpu_list, retry_with_instance_list import logging logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ @pytest.fixture -def input(): +def model_input(): return {"inputs": "The man worked as a [MASK]."} @@ -87,11 +87,14 @@ def model_builder(request): @pytest.mark.skipif( PYTHON_VERSION_IS_NOT_310, - reason="Testing feature", + tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS + and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS, + reason="no ml.p2 or ml.p3 instances in this region", ) +@retry_with_instance_list(gpu_list(tests.integ.test_region())) @pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True) def test_pytorch_transformers_sagemaker_endpoint( - sagemaker_session, model_builder, gpu_instance_type, input + sagemaker_session, model_builder, model_input, **kwargs ): logger.info("Running in 
SAGEMAKER_ENDPOINT mode...") caught_ex = None @@ -106,9 +109,12 @@ def test_pytorch_transformers_sagemaker_endpoint( with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): try: logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") - predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1) + predictor = model.deploy( + instance_type=kwargs["instance_type"], initial_instance_count=2 + ) logger.info("Endpoint successfully deployed.") - predictor.predict(input) + predictor.predict(model_input) + assert predictor is not None except Exception as e: caught_ex = e finally: diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index becf63ab41..1f743ff442 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -38,6 +38,7 @@ version = "version" ENV_VARS = {"some key": "some value", "MODEL_CLASS_NAME": f"{module}.{class_name}"} ENV_VARS_INF_SPEC = {"some key": "some value"} +INSTANCE_GPU_INFO = (2, 8) mock_image_uri = "abcd/efghijk" mock_1p_dlc_image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com" @@ -52,6 +53,9 @@ ModelServer.DJL_SERVING, } +MIB_CONVERSION_FACTOR = 0.00000095367431640625 +MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer + mock_session = MagicMock() @@ -1077,3 +1081,515 @@ def test_build_negative_path_when_schema_builder_not_present( "Error Message: Schema builder for text-to-image could not be found.", lambda: model_builder.build(sagemaker_session=mock_session), ) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_can_fit_on_single_gpu( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_can_fit_on_single_gpu, + ): + # Setup mocks + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="meta-llama/Llama-2-7b-hf") + model_builder.build(sagemaker_session=mock_session) + + mock_can_fit_on_single_gpu.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + 
@patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_is_deepspeed_model( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_can_fit_on_single_gpu, + mock_build_for_djl, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + mock_can_fit_on_single_gpu.return_value = False + + model_builder = ModelBuilder(model="stable-diffusion") + model_builder.build(sagemaker_session=mock_session) + + mock_build_for_djl.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_for_transformers_happy_case( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_can_fit_on_single_gpu, + mock_build_for_transformers, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + mock_can_fit_on_single_gpu.return_value = True + + model_builder = ModelBuilder(model="stable-diffusion") + model_builder.build(sagemaker_session=mock_session) + + mock_build_for_transformers.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_for_transformers_happy_case_with_values( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + 
mock_total_inference_model_size_mib, + mock_try_fetch_gpu_info, + mock_build_for_transformers, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + mock_try_fetch_gpu_info.return_value = 2 + mock_total_inference_model_size_mib.return_value = 2 + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="stable-diffusion") + model_builder.build(sagemaker_session=mock_session) + + mock_build_for_transformers.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl", Mock()) + @patch("sagemaker.serve.builder.model_builder._get_gpu_info") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_for_transformers_happy_case_with_valid_gpu_info( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_total_inference_model_size_mib, + mock_try_fetch_gpu_info, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + mock_try_fetch_gpu_info.return_value = INSTANCE_GPU_INFO + mock_total_inference_model_size_mib.return_value = 10_000 + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="stable-diffusion") + model_builder.build(sagemaker_session=mock_session) + self.assertEqual( + model_builder._try_fetch_gpu_info(), INSTANCE_GPU_INFO[1] / INSTANCE_GPU_INFO[0] + ) + self.assertEqual(model_builder._can_fit_on_single_gpu(), False) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) + @patch("sagemaker.serve.builder.model_builder._get_gpu_info") + @patch("sagemaker.serve.builder.model_builder._get_gpu_info_fallback") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_for_transformers_happy_case_with_valid_gpu_fallback( + self, + mock_serveSettings, + 
mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_total_inference_model_size_mib, + mock_gpu_fallback, + mock_try_fetch_gpu_info, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + mock_try_fetch_gpu_info.side_effect = ValueError + mock_gpu_fallback.return_value = INSTANCE_GPU_INFO + mock_total_inference_model_size_mib.return_value = ( + INSTANCE_GPU_INFO[1] / INSTANCE_GPU_INFO[0] - 1 + ) + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder( + model="stable-diffusion", + sagemaker_session=mock_session, + instance_type=mock_instance_type, + ) + self.assertEqual( + model_builder._try_fetch_gpu_info(), INSTANCE_GPU_INFO[1] / INSTANCE_GPU_INFO[0] + ) + self.assertEqual(model_builder._can_fit_on_single_gpu(), True) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) + @patch("sagemaker.serve.builder.model_builder.estimate_command_parser") + @patch("sagemaker.serve.builder.model_builder.gather_data") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_for_transformers_happy_case_hugging_face_responses( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_gather_data, + mock_parser, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + mock_parser.return_value = Mock() + mock_gather_data.return_value = [[1, 1, 1, 1]] + product = MIB_CONVERSION_FACTOR * 1 * MEMORY_BUFFER_MULTIPLIER + + model_builder = ModelBuilder( + model="stable-diffusion", + sagemaker_session=mock_session, + instance_type=mock_instance_type, + ) + self.assertEqual(model_builder._total_inference_model_size_mib(), product) + + mock_parser.return_value = Mock() + mock_gather_data.return_value = None + model_builder = ModelBuilder( + model="stable-diffusion", + sagemaker_session=mock_session, + instance_type=mock_instance_type, + ) + with self.assertRaises(ValueError) as _: + model_builder._total_inference_model_size_mib() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") + 
@patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_is_fast_transformers_model( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_can_fit_on_single_gpu, + mock_build_for_djl, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + mock_can_fit_on_single_gpu.return_value = False + + model_builder = ModelBuilder(model="gpt_neo") + model_builder.build(sagemaker_session=mock_session) + + mock_build_for_djl.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_fallback_to_transformers( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_can_fit_on_single_gpu, + mock_build_for_transformers, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + mock_build_for_transformers.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + mock_can_fit_on_single_gpu.return_value = False + + model_builder = ModelBuilder(model="gpt_llm_burt") + model_builder.build(sagemaker_session=mock_session) + + mock_build_for_transformers.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_text_generation( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + 
mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_build_for_tgi, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + mock_build_for_tgi.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="bloom-560m") + model_builder.build(sagemaker_session=mock_session) + + mock_build_for_tgi.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_try_fetch_gpu_info_throws( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_can_fit_on_single_gpu, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + mock_can_fit_on_single_gpu.side_effect = ValueError + + model_builder = ModelBuilder(model="gpt_llm_burt") + model_builder.build(sagemaker_session=mock_session) + + self.assertEqual(model_builder._can_fit_on_single_gpu(), False) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_total_inference_model_size_mib_throws( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_total_inference_model_size_mib, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + 
mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + mock_total_inference_model_size_mib.side_effect = ValueError + + model_builder = ModelBuilder(model="gpt_llm_burt") + model_builder.build(sagemaker_session=mock_session) + + self.assertEqual(model_builder._can_fit_on_single_gpu(), False) diff --git a/tests/unit/sagemaker/serve/utils/test_task.py b/tests/unit/sagemaker/serve/utils/test_task.py index 78553968e1..431888e249 100644 --- a/tests/unit/sagemaker/serve/utils/test_task.py +++ b/tests/unit/sagemaker/serve/utils/test_task.py @@ -18,7 +18,7 @@ from sagemaker.serve.utils import task -EXPECTED_INPUTS = {"inputs": "Paris is the of France.", "parameters": {}} +EXPECTED_INPUTS = {"inputs": "Paris is the [MASK] of France.", "parameters": {}} EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}] HF_INVALID_TASK = "not-present-task" From 45a471fd1934439e76f12423b5ea14e2a57782da Mon Sep 17 00:00:00 2001 From: adtian2 <55163384+adtian2@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:37:38 -0800 Subject: [PATCH 18/42] change: Upgrade smp to version 2.2 (#4479) * upgrading smp to version 2.2 * fixing linting issue * fixing syntax error with multiline if statement * upgrading smp to version 2.2 * fixing linting issue * fixing syntax error with multiline if statement * fixing formatting --------- Co-authored-by: Andrew Tian --- src/sagemaker/fw_utils.py | 10 ++++++- .../image_uri_config/pytorch-smp.json | 28 ++++++++++++++++++- src/sagemaker/image_uris.py | 6 +++- .../unit/sagemaker/image_uris/test_smp_v2.py | 2 +- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 1fb8c8eaaa..449beeb55d 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -141,6 +141,7 @@ "2.0.1", "2.1.0", "2.1.2", + "2.2.0", ], } @@ -160,7 +161,14 @@ ] -TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.2"] +TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [ + "1.13.1", + "2.0.0", + "2.0.1", + "2.1.0", + "2.1.2", + "2.2.0", +] TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [ diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index d71c2df6ec..933c0fa437 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -5,7 +5,8 @@ ], "version_aliases": { "2.0": "2.0.1", - "2.1": "2.1.2" + "2.1": "2.1.2", + "2.2": "2.2.0" }, "versions": { "2.0.1": { @@ -57,6 +58,31 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.2.0": { + "py_versions": [ + "py310" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git 
a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 99692d0f8b..4004147fd9 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -682,7 +682,11 @@ def get_training_image_uri( if "modelparallel" in distribution["smdistributed"]: if distribution["smdistributed"]["modelparallel"].get("enabled", True): framework = "pytorch-smp" - if "p5" in instance_type or "2.1" in framework_version: + if ( + "p5" in instance_type + or "2.1" in framework_version + or "2.2" in framework_version + ): container_version = "cu121" else: container_version = "cu118" diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index 36accdebbb..b53a45133e 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -35,7 +35,7 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if "2.1" in version: + if "2.1" in version or "2.2" in version: cuda_vers = "cu121" uri = image_uris.get_training_image_uri( From b426c21a41b193972072be6b16e65d47dfafa935 Mon Sep 17 00:00:00 2001 From: Sirut Buasai <73297481+sirutBuasai@users.noreply.github.com> Date: Wed, 6 Mar 2024 08:26:56 -0800 Subject: [PATCH 19/42] feat: Update SM Python SDK for PT 2.2.0 SM DLC (#4481) * update pt2.2 sm training dlc pysdk * update pt2.2 sm inference dlc pysdk and region list --- src/sagemaker/fw_utils.py | 1 + src/sagemaker/image_uri_config/pytorch.json | 101 +++++++++++++++++++- tests/unit/test_fw_utils.py | 3 +- tests/unit/test_utils.py | 2 + 4 files changed, 104 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 449beeb55d..ca5c09a96c 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -158,6 +158,7 @@ "2.0.0", "2.0.1", "2.1.0", + "2.2.0", ] diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 65b8513c0a..85b454ebed 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -848,6 +848,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -856,11 +857,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -887,6 +890,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -895,11 +899,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -926,6 +932,7 @@ "ap-northeast-2": 
"763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -934,11 +941,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -965,6 +974,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -973,11 +983,55 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "2.2.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -1190,7 +1244,8 @@ "1.12": "1.12.1", "1.13": "1.13.1", "2.0": "2.0.1", - "2.1": "2.1.0" + "2.1": "2.1.0", + "2.2": "2.2.0" }, "versions": { "0.4.0": { @@ -2113,7 +2168,49 @@ "ca-west-1": "204538143572" }, "repository": "pytorch-training" + }, + "2.2.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": 
"763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" } } } -} \ No newline at end of file +} diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 4600785159..ae4cfe8ab5 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -998,6 +998,7 @@ def test_validate_pytorchddp_not_raises(): "2.0.0", "2.0.1", "2.1.0", + "2.2.0", ] for framework_version in pytorchddp_supported_fw_versions: fw_utils.validate_pytorch_distribution( @@ -1060,7 +1061,7 @@ def test_validate_torch_distributed_not_raises(): # Case 3: Distribution is torch_distributed enabled, supported framework and instances torch_distributed_enabled = {"torch_distributed": {"enabled": True}} - torch_distributed_gpu_supported_fw_versions = ["1.13.1", "2.0.0", "2.0.1", "2.1.0"] + torch_distributed_gpu_supported_fw_versions = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.2.0"] for framework_version in torch_distributed_gpu_supported_fw_versions: fw_utils.validate_torch_distributed_distribution( instance_type="ml.p3.8xlarge", diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 8488a8308e..b5376d3556 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -384,6 +384,8 @@ def test_set_nested_value(): def test_get_short_version(): + assert sagemaker.utils.get_short_version("2.2.0") == "2.2" + assert sagemaker.utils.get_short_version("2.2") == "2.2" assert sagemaker.utils.get_short_version("2.1.0") == "2.1" assert sagemaker.utils.get_short_version("2.1") == "2.1" assert sagemaker.utils.get_short_version("2.0.1") == "2.0" From 7000f2566427051e227526df9d1801d583fcf0b7 Mon Sep 17 00:00:00 2001 From: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:08:12 -0800 Subject: [PATCH 20/42] fix: Create custom tarfile extractall util to fix backward compatibility issue (#4476) * fix: Create custom tarfile extractall util to fix backward compatibility issue * Address review comments * fix logger.error statements --- src/sagemaker/local/image.py | 5 +- .../serve/model_server/djl_serving/prepare.py | 5 +- .../serve/model_server/tgi/prepare.py | 5 +- src/sagemaker/utils.py | 107 ++++++++++++++---- src/sagemaker/workflow/_repack_model.py | 98 +++++++++++++++- src/sagemaker/workflow/_utils.py | 5 +- tests/integ/s3_utils.py | 5 +- tests/unit/test_fw_utils.py | 5 +- tests/unit/test_utils.py | 56 +++++++-- 9 files changed, 240 insertions(+), 51 deletions(-) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 39c879ef6d..377bdcac85 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -40,7 +40,7 @@ import sagemaker.local.data import sagemaker.local.utils import sagemaker.utils -from sagemaker.utils import check_tarfile_data_filter_attribute +from sagemaker.utils import custom_extractall_tarfile CONTAINER_PREFIX = "algo" STUDIO_HOST_NAME = "sagemaker-local" @@ -687,8 +687,7 @@ def _prepare_serving_volumes(self, model_location): for filename in model_data_source.get_file_list(): if tarfile.is_tarfile(filename): with tarfile.open(filename) as tar: - check_tarfile_data_filter_attribute() - 
tar.extractall(path=model_data_source.get_root_dir(), filter="data") + custom_extractall_tarfile(tar, model_data_source.get_root_dir()) volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model")) diff --git a/src/sagemaker/serve/model_server/djl_serving/prepare.py b/src/sagemaker/serve/model_server/djl_serving/prepare.py index 6bdada0b6c..810acc8aff 100644 --- a/src/sagemaker/serve/model_server/djl_serving/prepare.py +++ b/src/sagemaker/serve/model_server/djl_serving/prepare.py @@ -20,7 +20,7 @@ from typing import List from pathlib import Path -from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute +from sagemaker.utils import _tmpdir, custom_extractall_tarfile from sagemaker.s3 import S3Downloader from sagemaker.djl_inference import DJLModel from sagemaker.djl_inference.model import _read_existing_serving_properties @@ -53,8 +53,7 @@ def _extract_js_resource(js_model_dir: str, js_id: str): """Uncompress the jumpstart resource""" tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") with tarfile.open(str(tmp_sourcedir)) as resources: - check_tarfile_data_filter_attribute() - resources.extractall(path=js_model_dir, filter="data") + custom_extractall_tarfile(resources, js_model_dir) def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path): diff --git a/src/sagemaker/serve/model_server/tgi/prepare.py b/src/sagemaker/serve/model_server/tgi/prepare.py index 9b187dd2ed..af09515da9 100644 --- a/src/sagemaker/serve/model_server/tgi/prepare.py +++ b/src/sagemaker/serve/model_server/tgi/prepare.py @@ -19,7 +19,7 @@ from pathlib import Path from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage -from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute +from sagemaker.utils import _tmpdir, custom_extractall_tarfile from sagemaker.s3 import S3Downloader logger = logging.getLogger(__name__) @@ -29,8 +29,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str): """Uncompress the jumpstart resource""" tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") with tarfile.open(str(tmp_sourcedir)) as resources: - check_tarfile_data_filter_attribute() - resources.extractall(path=code_dir, filter="data") + custom_extractall_tarfile(resources, code_dir) def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool: diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index a6d26db48b..115b8b258d 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -22,7 +22,6 @@ import random import re import shutil -import sys import tarfile import tempfile import time @@ -31,6 +30,7 @@ import abc import uuid from datetime import datetime +from os.path import abspath, realpath, dirname, normpath, join as joinpath from importlib import import_module import botocore @@ -592,8 +592,7 @@ def _create_or_update_code_dir( download_file_from_url(source_directory, local_code_path, sagemaker_session) with tarfile.open(name=local_code_path, mode="r:gz") as t: - check_tarfile_data_filter_attribute() - t.extractall(path=code_dir, filter="data") + custom_extractall_tarfile(t, code_dir) elif source_directory: if os.path.exists(code_dir): @@ -630,8 +629,7 @@ def _extract_model(model_uri, sagemaker_session, tmp): else: local_model_path = model_uri.replace("file://", "") with tarfile.open(name=local_model_path, mode="r:gz") as t: - check_tarfile_data_filter_attribute() - t.extractall(path=tmp_model_dir, filter="data") + 
custom_extractall_tarfile(t, tmp_model_dir) return tmp_model_dir @@ -1494,23 +1492,92 @@ def format_tags(tags: Tags) -> List[TagsDict]: return tags -class PythonVersionError(Exception): - """Raise when a secure [/patched] version of Python is not used.""" +def _get_resolved_path(path): + """Return the normalized absolute path of a given path. + abspath - returns the absolute path without resolving symlinks + realpath - resolves the symlinks and gets the actual path + normpath - normalizes paths (e.g. remove redudant separators) + and handles platform-specific differences + """ + return normpath(realpath(abspath(path))) -def check_tarfile_data_filter_attribute(): - """Check if tarfile has data_filter utility. - Tarfile-data_filter utility has guardrails against untrusted de-serialisation. +def _is_bad_path(path, base): + """Checks if the joined path (base directory + file path) is rooted under the base directory - Raises: - PythonVersionError: if `tarfile.data_filter` is not available. + Ensuring that the file does not attempt to access paths + outside the expected directory structure. + + Args: + path (str): The file path. + base (str): The base directory. + + Returns: + bool: True if the path is not rooted under the base directory, False otherwise. """ - # The function and it's usages can be deprecated post support of python >= 3.12 - if not hasattr(tarfile, "data_filter"): - raise PythonVersionError( - f"Since tarfile extraction is unsafe the operation is prohibited " - f"per PEP-721. Please update your Python [{sys.version}] " - f"to latest patch [refer to https://www.python.org/downloads/] " - f"to consume the security patch" - ) + # joinpath will ignore base if path is absolute + return not _get_resolved_path(joinpath(base, path)).startswith(base) + + +def _is_bad_link(info, base): + """Checks if the link is rooted under the base directory. + + Ensuring that the link does not attempt to access paths outside the expected directory structure + + Args: + info (tarfile.TarInfo): The tar file info. + base (str): The base directory. + + Returns: + bool: True if the link is not rooted under the base directory, False otherwise. + """ + # Links are interpreted relative to the directory containing the link + tip = _get_resolved_path(joinpath(base, dirname(info.name))) + return _is_bad_path(info.linkname, base=tip) + + +def _get_safe_members(members): + """A generator that yields members that are safe to extract. + + It filters out bad paths and bad links. + + Args: + members (list): A list of members to check. + + Yields: + tarfile.TarInfo: The tar file info. + """ + base = _get_resolved_path(".") + + for file_info in members: + if _is_bad_path(file_info.name, base): + logger.error("%s is blocked (illegal path)", file_info.name) + elif file_info.issym() and _is_bad_link(file_info, base): + logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname) + elif file_info.islnk() and _is_bad_link(file_info, base): + logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname) + else: + yield file_info + + +def custom_extractall_tarfile(tar, extract_path): + """Extract a tarfile, optionally using data_filter if available. + + # TODO: The function and it's usages can be deprecated once SageMaker Python SDK + is upgraded to use Python 3.12+ + + If the tarfile has a data_filter attribute, it will be used to extract the contents of the file. + Otherwise, the _get_safe_members function will be used to filter bad paths and bad links. 
+ + Args: + tar (tarfile.TarFile): The opened tarfile object. + extract_path (str): The path to extract the contents of the tarfile. + + Returns: + None + """ + if hasattr(tarfile, "data_filter"): + tar.extractall(path=extract_path, filter="data") + else: + tar.extractall(path=extract_path, members=_get_safe_members(tar)) diff --git a/src/sagemaker/workflow/_repack_model.py b/src/sagemaker/workflow/_repack_model.py index 3cfa6760b3..84b3a426f6 100644 --- a/src/sagemaker/workflow/_repack_model.py +++ b/src/sagemaker/workflow/_repack_model.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import argparse +import logging import os import shutil import tarfile @@ -33,6 +34,101 @@ # repacking is some short-lived hackery, right?? from distutils.dir_util import copy_tree +from os.path import abspath, realpath, dirname, normpath, join as joinpath + +logger = logging.getLogger(__name__) + + +def _get_resolved_path(path): + """Return the normalized absolute path of a given path. + + abspath - returns the absolute path without resolving symlinks + realpath - resolves the symlinks and gets the actual path + normpath - normalizes paths (e.g. remove redudant separators) + and handles platform-specific differences + """ + return normpath(realpath(abspath(path))) + + +def _is_bad_path(path, base): + """Checks if the joined path (base directory + file path) is rooted under the base directory + + Ensuring that the file does not attempt to access paths + outside the expected directory structure. + + Args: + path (str): The file path. + base (str): The base directory. + + Returns: + bool: True if the path is not rooted under the base directory, False otherwise. + """ + # joinpath will ignore base if path is absolute + return not _get_resolved_path(joinpath(base, path)).startswith(base) + + +def _is_bad_link(info, base): + """Checks if the link is rooted under the base directory. + + Ensuring that the link does not attempt to access paths outside the expected directory structure + + Args: + info (tarfile.TarInfo): The tar file info. + base (str): The base directory. + + Returns: + bool: True if the link is not rooted under the base directory, False otherwise. + """ + # Links are interpreted relative to the directory containing the link + tip = _get_resolved_path(joinpath(base, dirname(info.name))) + return _is_bad_path(info.linkname, base=tip) + + +def _get_safe_members(members): + """A generator that yields members that are safe to extract. + + It filters out bad paths and bad links. + + Args: + members (list): A list of members to check. + + Yields: + tarfile.TarInfo: The tar file info. + """ + base = _get_resolved_path(".") + + for file_info in members: + if _is_bad_path(file_info.name, base): + logger.error("%s is blocked (illegal path)", file_info.name) + elif file_info.issym() and _is_bad_link(file_info, base): + logger.error("%s is blocked: Symlink to %s", file_info.name, file_info.linkname) + elif file_info.islnk() and _is_bad_link(file_info, base): + logger.error("%s is blocked: Hard link to %s", file_info.name, file_info.linkname) + else: + yield file_info + + +def custom_extractall_tarfile(tar, extract_path): + """Extract a tarfile, optionally using data_filter if available. + + # TODO: The function and it's usages can be deprecated once SageMaker Python SDK + is upgraded to use Python 3.12+ + + If the tarfile has a data_filter attribute, it will be used to extract the contents of the file. + Otherwise, the _get_safe_members function will be used to filter bad paths and bad links. 
+ + Args: + tar (tarfile.TarFile): The opened tarfile object. + extract_path (str): The path to extract the contents of the tarfile. + + Returns: + None + """ + if hasattr(tarfile, "data_filter"): + tar.extractall(path=extract_path, filter="data") + else: + tar.extractall(path=extract_path, members=_get_safe_members(tar)) + def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover """Repack custom dependencies and code into an existing model TAR archive @@ -60,7 +156,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): # extract the contents of the previous training job's model archive to the "src" # directory of this training job with tarfile.open(name=local_path, mode="r:gz") as tf: - tf.extractall(path=src_dir) + custom_extractall_tarfile(tf, src_dir) if source_dir: # copy /opt/ml/code to code/ diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 1b88bfd924..1fafa646bf 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -36,7 +36,7 @@ _save_model, download_file_from_url, format_tags, - check_tarfile_data_filter_attribute, + custom_extractall_tarfile, ) from sagemaker.workflow.retry import RetryPolicy from sagemaker.workflow.utilities import trim_request_dict @@ -262,8 +262,7 @@ def _inject_repack_script_and_launcher(self): download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session) with tarfile.open(name=old_targz_path, mode="r:gz") as t: - check_tarfile_data_filter_attribute() - t.extractall(path=targz_contents_dir, filter="data") + custom_extractall_tarfile(t, targz_contents_dir) shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT)) with open( diff --git a/tests/integ/s3_utils.py b/tests/integ/s3_utils.py index 500dc4a33a..d5839c8409 100644 --- a/tests/integ/s3_utils.py +++ b/tests/integ/s3_utils.py @@ -19,7 +19,7 @@ import boto3 from six.moves.urllib.parse import urlparse -from sagemaker.utils import check_tarfile_data_filter_attribute +from sagemaker.utils import custom_extractall_tarfile def assert_s3_files_exist(sagemaker_session, s3_url, files): @@ -57,5 +57,4 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session): s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model) with tarfile.open(model, "r") as tar_file: - check_tarfile_data_filter_attribute() - tar_file.extractall(tmpdir, filter="data") + custom_extractall_tarfile(tar_file, tmpdir) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index ae4cfe8ab5..7fa84acf1a 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -24,7 +24,7 @@ from mock import Mock, patch from sagemaker import fw_utils -from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute +from sagemaker.utils import name_from_image, custom_extractall_tarfile from sagemaker.session_settings import SessionSettings from sagemaker.instance_group import InstanceGroup @@ -424,8 +424,7 @@ def list_tar_files(folder, tar_ball, tmpdir): startpath = str(tmpdir.ensure(folder, dir=True)) with tarfile.open(name=tar_ball, mode="r:gz") as t: - check_tarfile_data_filter_attribute() - t.extractall(path=startpath, filter="data") + custom_extractall_tarfile(t, startpath) def walk(): for root, dirs, files in os.walk(startpath): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index b5376d3556..852ef8b153 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -42,8 +42,10 @@ 
resolve_nested_dict_value_from_config, update_list_of_dicts_with_values_from_config, volume_size_supported, - PythonVersionError, - check_tarfile_data_filter_attribute, + _get_resolved_path, + _is_bad_path, + _is_bad_link, + custom_extractall_tarfile, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1078,7 +1080,7 @@ def list_tar_files(tar_ball, tmp): os.mkdir(startpath) with tarfile.open(name=tar_ball, mode="r:gz") as t: - t.extractall(path=startpath) + custom_extractall_tarfile(t, startpath) def walk(): for root, dirs, files in os.walk(startpath): @@ -1754,13 +1756,43 @@ def test_instance_family_from_full_instance_type(self): self.assertEqual(family, get_instance_type_family(instance_type)) -class TestCheckTarfileDataFilterAttribute(TestCase): - def test_check_tarfile_data_filter_attribute_unhappy_case(self): - with pytest.raises(PythonVersionError): - with patch("tarfile.data_filter", None): - delattr(tarfile, "data_filter") - check_tarfile_data_filter_attribute() +@pytest.fixture +def mock_custom_tarfile(): + class MockTarfile: + def __init__(self, data_filter=False): + self.data_filter = data_filter - def test_check_tarfile_data_filter_attribute_happy_case(self): - with patch("tarfile.data_filter", "some_value"): - check_tarfile_data_filter_attribute() + def extractall(self, path, members=None, filter=None): + assert path == "/extract/path" + if members is not None: + assert next(members).name == "file.txt" + + return MockTarfile + + +def test_get_resolved_path(): + assert _get_resolved_path("path/to/file") == os.path.normpath( + os.path.realpath(os.path.abspath("path/to/file")) + ) + + +@pytest.mark.parametrize("file_path, base, expected", [("file.txt", "/path/to/base", False)]) +def test_is_bad_path(file_path, base, expected): + assert _is_bad_path(file_path, base) == expected + + +@pytest.mark.parametrize( + "link_name, base, expected", [("link_to_file.txt", "/path/to/base", False)] +) +def test_is_bad_link(link_name, base, expected): + dummy_info = tarfile.TarInfo(name="dummy.txt") + dummy_info.linkname = link_name + assert _is_bad_link(dummy_info, base) == expected + + +@pytest.mark.parametrize( + "data_filter, expected_extract_path", [(True, "/extract/path"), (False, "/extract/path")] +) +def test_custom_extractall_tarfile(mock_custom_tarfile, data_filter, expected_extract_path): + tar = mock_custom_tarfile(data_filter) + custom_extractall_tarfile(tar, "/extract/path") From 48d501f2a21f6c63f14e33b3e080f0b5c28dc74c Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 6 Mar 2024 19:57:29 +0000 Subject: [PATCH 21/42] prepare release v2.212.0 --- CHANGELOG.md | 12 ++++++++++++ VERSION | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f42089c8c..699a44c787 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v2.212.0 (2024-03-06) + +### Features + + * Update SM Python SDK for PT 2.2.0 SM DLC + +### Bug Fixes and Other Changes + + * Create custom tarfile extractall util to fix backward compatibility issue + * Upgrade smp to version 2.2 + * Enhance model builder selection logic to include model size + ## v2.211.0 (2024-03-05) ### Features diff --git a/VERSION b/VERSION index e2c63e5edf..2fdb49ca3f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.211.1.dev0 +2.212.0 From 4c8e0fc813c96253ca91465df7d1d5c7fe25c4f3 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 6 Mar 2024 19:57:30 +0000 Subject: [PATCH 22/42] update development 
version to v2.212.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 2fdb49ca3f..4e29bf93f0 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.212.0 +2.212.1.dev0 From 65f2ddf9b9a5ce4bcdd158404dcff7a86a23138c Mon Sep 17 00:00:00 2001 From: Danny Bushkanets Date: Wed, 6 Mar 2024 21:35:50 -0500 Subject: [PATCH 23/42] change: Update tblib constraint (#4452) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5b8845efed..f37058eb4a 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def read_requirements(filename): "PyYAML~=6.0", "jsonschema", "platformdirs", - "tblib>=1.7.0,<3", + "tblib>=1.7.0,<4", "urllib3>=1.26.8,<3.0.0", "requests", "docker", From 3e9e04df7a821b9b8691d566ee7568d6d840585d Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:18:42 -0800 Subject: [PATCH 24/42] fix: make unit tests compatible with pytest-xdist (#4486) * fix: make unit tests compatible with pytest-xdist * fix failing test --- .../feature_processor/test_data_helpers.py | 14 ++++++++++++++ .../feature_processor/test_validation.py | 4 +--- .../remote_function/core/test_stored_function.py | 3 +-- .../test_huggingface_pytorch_compiler.py | 6 +++--- .../test_huggingface_tensorflow_compiler.py | 4 ++-- .../training_compiler/test_pytorch_compiler.py | 6 +++--- .../training_compiler/test_tensorflow_compiler.py | 8 +++++--- .../unit/sagemaker/workflow/test_training_step.py | 4 ++++ 8 files changed, 33 insertions(+), 16 deletions(-) diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py b/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py index 9c4f0fef49..bd572c7694 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_data_helpers.py @@ -52,6 +52,20 @@ "some-other-key": {"some-key": "some-value"}, } +DATA_SOURCE_UNIQUE_ID_TOO_LONG = """ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +""" + DESCRIBE_FEATURE_GROUP_RESPONSE = { "FeatureGroupArn": INPUT_FEATURE_GROUP_ARN, "FeatureGroupName": INPUT_FEATURE_GROUP_NAME, diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py b/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py index 8e0115afd2..b0fde3274b 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_validation.py @@ -19,8 +19,6 @@ import pytest import test_data_helpers as tdh -import string -import random from mock import Mock from sagemaker.feature_store.feature_processor._validation import ( @@ -164,7 +162,7 @@ def invalid_spark_position(spark, fg_data_source, s3_data_source): ("", "unique_id", 
"data_source_name of input does not match pattern '.*'."), ( "source", - "".join(random.choices(string.ascii_uppercase, k=2050)), + tdh.DATA_SOURCE_UNIQUE_ID_TOO_LONG, "data_source_unique_id of input does not match pattern '.*'.", ), ("source", "", "data_source_unique_id of input does not match pattern '.*'."), diff --git a/tests/unit/sagemaker/remote_function/core/test_stored_function.py b/tests/unit/sagemaker/remote_function/core/test_stored_function.py index b263682641..68a05c08a6 100644 --- a/tests/unit/sagemaker/remote_function/core/test_stored_function.py +++ b/tests/unit/sagemaker/remote_function/core/test_stored_function.py @@ -39,7 +39,6 @@ from sagemaker.workflow.function_step import _FunctionStep, DelayedReturn from sagemaker.workflow.parameters import ParameterFloat -from sagemaker.utils import sagemaker_timestamp from tests.unit.sagemaker.experiments.helpers import ( TEST_EXP_DISPLAY_NAME, @@ -55,7 +54,7 @@ FUNCTION_FOLDER = "function" ARGUMENT_FOLDER = "arguments" RESULT_FOLDER = "results" -PIPELINE_BUILD_TIME = sagemaker_timestamp() +PIPELINE_BUILD_TIME = "2022-05-10T17:30:20Z" mock_s3 = {} diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 96f6998af6..2b59113354 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -202,7 +202,7 @@ def test_unsupported_cpu_instance( ).fit() -@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("unsupported_gpu_instance_class", sorted(UNSUPPORTED_GPU_INSTANCE_CLASSES)) def test_unsupported_gpu_instance( unsupported_gpu_instance_class, huggingface_training_compiler_version, @@ -366,7 +366,7 @@ def test_unsupported_distribution( @patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) @patch("time.time", return_value=TIME) -@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("instance_class", sorted(SUPPORTED_GPU_INSTANCE_CLASSES)) def test_pytorchxla_distribution( time, name_from_base, @@ -430,7 +430,7 @@ def test_pytorchxla_distribution( @patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) @patch("time.time", return_value=TIME) -@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("instance_class", sorted(SUPPORTED_GPU_INSTANCE_CLASSES)) def test_default_compiler_config( time, name_from_base, diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index a650379dfd..dfe4d10c3a 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -201,7 +201,7 @@ def test_unsupported_cpu_instance( ).fit() -@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("unsupported_gpu_instance_class", sorted(UNSUPPORTED_GPU_INSTANCE_CLASSES)) def test_unsupported_gpu_instance( unsupported_gpu_instance_class, huggingface_training_compiler_version, @@ -315,7 +315,7 @@ def test_unsupported_distribution( 
@patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) @patch("time.time", return_value=TIME) -@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("instance_class", sorted(SUPPORTED_GPU_INSTANCE_CLASSES)) def test_default_compiler_config( time, name_from_base, diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 9a7ba698f3..7417e006a1 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -191,7 +191,7 @@ def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_v ).fit() -@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("unsupported_gpu_instance_class", sorted(UNSUPPORTED_GPU_INSTANCE_CLASSES)) def test_unsupported_gpu_instance( unsupported_gpu_instance_class, pytorch_training_compiler_version ): @@ -309,7 +309,7 @@ def test_unsupported_distribution( @patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) @patch("time.time", return_value=TIME) -@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("instance_class", sorted(SUPPORTED_GPU_INSTANCE_CLASSES)) def test_pytorchxla_distribution( time, name_from_base, @@ -372,7 +372,7 @@ def test_pytorchxla_distribution( @patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) @patch("time.time", return_value=TIME) -@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +@pytest.mark.parametrize("instance_class", sorted(SUPPORTED_GPU_INSTANCE_CLASSES)) def test_default_compiler_config( time, name_from_base, diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 67530bc288..ebad1366ee 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -176,7 +176,9 @@ def test_cpu_instance( compiler_config=TrainingCompilerConfig(), ).fit() - @pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES) + @pytest.mark.parametrize( + "unsupported_gpu_instance_class", sorted(UNSUPPORTED_GPU_INSTANCE_CLASSES) + ) def test_gpu_instance( self, unsupported_gpu_instance_class, @@ -254,7 +256,7 @@ def test_python_2(self, tensorflow_training_version): @patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) @patch("time.time", return_value=TIME) class TestTrainingCompilerConfig: - @pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) + @pytest.mark.parametrize("instance_class", sorted(SUPPORTED_GPU_INSTANCE_CLASSES)) def test_default( self, time, @@ -308,7 +310,7 @@ def test_default( actual_train_args == expected_train_args ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" - @pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) + @pytest.mark.parametrize("instance_class", sorted(SUPPORTED_GPU_INSTANCE_CLASSES)) def test_byoc( self, time, diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 
673e54dbbe..f31eb07d85 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -263,6 +263,8 @@ Join(on="/", values=["s3://my-bucket", "my-input"]), ] +OUTPUT_PARAM_LIST = ["s3://my-bucket/my-output-path", ParameterString(name="OutputPath")] + @pytest.fixture def training_input(): @@ -454,6 +456,7 @@ def test_training_step_estimator_with_param_code_input( assert step_def == step_def2 +@pytest.mark.skip(reason="incompatible with pytest-xdist") @pytest.mark.parametrize("estimator", ESTIMATOR_LISTS) @pytest.mark.parametrize("training_input", INPUT_PARAM_LISTS) @pytest.mark.parametrize( @@ -523,6 +526,7 @@ def test_training_step_with_framework_estimator( assert step_def == step_def2 +@pytest.mark.skip(reason="incompatible with pytest-xdist") @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) @pytest.mark.parametrize("estimator", ESTIMATOR_LISTS_LOCAL_CODE) @pytest.mark.parametrize("training_input", INPUT_PARAM_LISTS) From 554d7209cc89935372971f29da614d84330aeaea Mon Sep 17 00:00:00 2001 From: xiongz945 <54782408+xiongz945@users.noreply.github.com> Date: Thu, 7 Mar 2024 18:31:44 -0800 Subject: [PATCH 25/42] feature: Add overriding logic in ModelBuilder when task is provided (#4460) * feat: Add Optional task to Model * Revert "feat: Add Optional task to Model" This reverts commit fd3e86b19f091dc1ce4d7906644107e08768c6a4. * Add override logic in ModelBuilder with task provided * Adjusted formatting * Add extra unit tests for invalid inputs * Address PR comments * Add more test inputs to integration test * Add model_metadata field to ModelBuilder * Update doc * Update doc * Adjust formatting --------- Co-authored-by: Samrudhi Sharma Co-authored-by: Xiong Zeng --- src/sagemaker/huggingface/llm_utils.py | 3 +- src/sagemaker/serve/builder/model_builder.py | 19 ++- src/sagemaker/serve/schema/task.json | 40 +++--- .../sagemaker/serve/test_schema_builder.py | 66 +++++++++ .../serve/builder/test_model_builder.py | 125 +++++++++++++++++- 5 files changed, 226 insertions(+), 27 deletions(-) diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index 1a2abfb2e4..de5e624dbc 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = Returns: dict: The model metadata retrieved with the HuggingFace API """ - + if not model_id: + raise ValueError("Model ID is empty. Please provide a valid Model ID.") hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}" hf_model_metadata_json = None try: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index c66057397f..4d1e51cb26 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -124,8 +124,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): into a stream. All translations between the server and the client are handled automatically with the specified input and output. model (Optional[Union[object, str]): Model object (with ``predict`` method to perform - inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or - ``inference_spec`` is required for the model builder to build the artifact. + inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec`` + is required for the model builder to build the artifact. 
inference_spec (InferenceSpec): The inference spec file with your customized ``invoke`` and ``load`` functions. image_uri (Optional[str]): The container image uri (which is derived from a @@ -145,6 +145,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): to the model server). Possible values for this argument are ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``, ``TRITON``, and``TGI``. + model_metadata (Optional[Dict[str, Any]): Dictionary used to override the HuggingFace + model metadata. Currently ``HF_TASK`` is overridable. """ model_path: Optional[str] = field( @@ -241,6 +243,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): model_server: Optional[ModelServer] = field( default=None, metadata={"help": "Define the model server to deploy to."} ) + model_metadata: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"}, + ) def _build_validations(self): """Placeholder docstring""" @@ -616,6 +622,9 @@ def build( # pylint: disable=R0911 self._is_custom_image_uri = self.image_uri is not None if isinstance(self.model, str): + model_task = None + if self.model_metadata: + model_task = self.model_metadata.get("HF_TASK") if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 @@ -625,10 +634,10 @@ def build( # pylint: disable=R0911 self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - model_task = hf_model_md.get("pipeline_tag") - if self.schema_builder is None and model_task: + if model_task is None: + model_task = hf_model_md.get("pipeline_tag") + if self.schema_builder is None and model_task is not None: self._schema_builder_init(model_task) - if model_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() elif self._can_fit_on_single_gpu(): diff --git a/src/sagemaker/serve/schema/task.json b/src/sagemaker/serve/schema/task.json index 1a7bdce5d0..c897f4abec 100644 --- a/src/sagemaker/serve/schema/task.json +++ b/src/sagemaker/serve/schema/task.json @@ -1,12 +1,12 @@ { "fill-mask": { - "sample_inputs": { + "sample_inputs": { "properties": { "inputs": "Paris is the [MASK] of France.", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "sequence": "Paris is the capital of France.", @@ -14,15 +14,15 @@ } ] } - }, + }, "question-answering": { - "sample_inputs": { + "sample_inputs": { "properties": { "context": "I have a German Shepherd dog, named Coco.", "question": "What is my dog's breed?" } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "answer": "German Shepherd", @@ -32,15 +32,15 @@ } ] } - }, + }, "text-classification": { - "sample_inputs": { + "sample_inputs": { "properties": { "inputs": "Where is the capital of France?, Paris is the capital of France.", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "label": "entailment", @@ -48,20 +48,20 @@ } ] } - }, - "text-generation": { - "sample_inputs": { + }, + "text-generation": { + "sample_inputs": { "properties": { "inputs": "Hello, I'm a language model", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ - { - "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" - } + { + "generated_text": "Hello, I'm a language modeler. 
So while writing this, when I went out to meet my wife or come home she told me that my" + } ] } - } + } } diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index 3816985d8f..2b6ac48460 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -99,3 +99,69 @@ def test_model_builder_negative_path(sagemaker_session): match="Error Message: Schema builder for text-to-image could not be found.", ): model_builder.build(sagemaker_session=sagemaker_session) + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing Schema Builder Simplification feature", +) +@pytest.mark.parametrize( + "model_id, task_provided", + [ + ("bert-base-uncased", "fill-mask"), + ("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"), + ], +) +def test_model_builder_happy_path_with_task_provided( + model_id, task_provided, sagemaker_session, gpu_instance_type +): + model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": task_provided}) + + model = model_builder.build(sagemaker_session=sagemaker_session) + + assert model is not None + assert model_builder.schema_builder is not None + + inputs, outputs = task.retrieve_local_schemas(task_provided) + assert model_builder.schema_builder.sample_input == inputs + assert model_builder.schema_builder.sample_output == outputs + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + caught_ex = None + try: + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy( + role=role_arn, instance_count=1, instance_type=gpu_instance_type + ) + + predicted_outputs = predictor.predict(inputs) + assert predicted_outputs is not None + + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test" + + +def test_model_builder_negative_path_with_invalid_task(sagemaker_session): + model_builder = ModelBuilder( + model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"} + ) + + with pytest.raises( + TaskNotFoundException, + match="Error Message: Schema builder for invalid-task could not be found.", + ): + model_builder.build(sagemaker_session=sagemaker_session) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 1f743ff442..3b60d13dfb 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1076,7 +1076,7 @@ def test_build_negative_path_when_schema_builder_not_present( model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4") - self.assertRaisesRegexp( + self.assertRaisesRegex( TaskNotFoundException, "Error Message: Schema builder for text-to-image could not be found.", lambda: model_builder.build(sagemaker_session=mock_session), @@ -1593,3 +1593,126 @@ def test_total_inference_model_size_mib_throws( model_builder.build(sagemaker_session=mock_session) self.assertEqual(model_builder._can_fit_on_single_gpu(), False) + + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") 
+ @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_happy_path_override_with_task_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_hf_model, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder( + model="bert-base-uncased", model_metadata={"HF_TASK": "text-generation"} + ) + model_builder.build(sagemaker_session=mock_session) + + self.assertIsNotNone(model_builder.schema_builder) + sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation") + self.assertEqual( + sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"] + ) + self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output) + + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_task_override_with_invalid_task_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + model_ids_with_invalid_task = { + "bert-base-uncased": "invalid-task", + "bert-large-uncased-whole-word-masking-finetuned-squad": "", + } + for model_id in model_ids_with_invalid_task: + provided_task = model_ids_with_invalid_task[model_id] + model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": provided_task}) + + self.assertRaisesRegex( + TaskNotFoundException, + f"Error Message: Schema builder for {provided_task} could not be found.", + lambda: model_builder.build(sagemaker_session=mock_session), + ) + + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def 
test_build_task_override_with_invalid_model_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_image_uris_retrieve, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + invalid_model_id = "" + provided_task = "fill-mask" + + model_builder = ModelBuilder( + model=invalid_model_id, model_metadata={"HF_TASK": provided_task} + ) + with self.assertRaises(Exception): + model_builder.build(sagemaker_session=mock_session) From 28a1665a5a3f29b1d080d4df2f4eeecb1c326ed0 Mon Sep 17 00:00:00 2001 From: martinRenou Date: Fri, 8 Mar 2024 19:19:46 +0100 Subject: [PATCH 26/42] feature: Accept user-defined env variables for the entry-point (#4175) --- src/sagemaker/model.py | 8 ++--- .../test_huggingface_pytorch_compiler.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ff340b58e9..5a2b27c54d 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: def _script_mode_env_vars(self): """Returns a mapping of environment variables for script mode execution""" - script_name = None - dir_name = None + script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "") + dir_name = self.env.get(DIR_PARAM_NAME.upper(), "") if self.uploaded_code: script_name = self.uploaded_code.script_name if self.repacked_model_data or self.enable_network_isolation(): @@ -783,8 +783,8 @@ def _script_mode_env_vars(self): else "file://" + self.source_dir ) return { - SCRIPT_PARAM_NAME.upper(): script_name or str(), - DIR_PARAM_NAME.upper(): dir_name or str(), + SCRIPT_PARAM_NAME.upper(): script_name, + DIR_PARAM_NAME.upper(): dir_name, CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 2b59113354..12162f799f 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -718,3 +718,32 @@ def test_register_hf_pytorch_model_auto_infer_framework( sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request ) + + +def test_accept_user_defined_environment_variables( + sagemaker_session, + huggingface_training_compiler_version, + huggingface_training_compiler_pytorch_version, + huggingface_training_compiler_pytorch_py_version, +): + program = "inference.py" + directory = "/opt/ml/model/code" + + hf_model = HuggingFaceModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + transformers_version=huggingface_training_compiler_version, + pytorch_version=huggingface_training_compiler_pytorch_version, + py_version=huggingface_training_compiler_pytorch_py_version, + sagemaker_session=sagemaker_session, + env={ + "SAGEMAKER_PROGRAM": program, + "SAGEMAKER_SUBMIT_DIRECTORY": directory, + }, + image_uri="fakeimage", + ) + + container_env = hf_model.prepare_container_def("ml.m4.xlarge")["Environment"] + + assert 
container_env["SAGEMAKER_PROGRAM"] == program + assert container_env["SAGEMAKER_SUBMIT_DIRECTORY"] == directory From c8f78f798b8e41886dba0277bcc1e8e83d50fcfa Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:42:55 -0700 Subject: [PATCH 27/42] fix: Move sagemaker pysdk version check after bootstrap in remote job (#4487) --- .../runtime_environment/bootstrap_runtime_environment.py | 7 ++++--- .../runtime_environment/runtime_environment_manager.py | 9 ++++----- .../test_bootstrap_runtime_environment.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index d5d879cb08..8fd83bfcfe 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,9 +65,6 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) user = getpass.getuser() if user != "root": @@ -89,6 +86,10 @@ def main(sys_args=None): client_python_version, conda_env, dependency_settings ) + RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( + client_sagemaker_pysdk_version + ) + exit_code = SUCCESS_EXIT_CODE except Exception as e: # pylint: disable=broad-except logger.exception("Error encountered while bootstrapping runtime environment: %s", e) diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py index 0dd5f0d219..13493c1d15 100644 --- a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py +++ b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -24,8 +24,6 @@ import dataclasses import json -import sagemaker - class _UTCFormatter(logging.Formatter): """Class that overrides the default local time provider in log formatter.""" @@ -330,6 +328,7 @@ def _current_python_version(self): def _current_sagemaker_pysdk_version(self): """Returns the current sagemaker python sdk version where program is running""" + import sagemaker return sagemaker.__version__ @@ -366,10 +365,10 @@ def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version): ): logger.warning( "Inconsistent sagemaker versions found: " - "sagemaker pysdk version found in the container is " + "sagemaker python sdk version found in the container is " "'%s' which does not match the '%s' on the local client. 
" - "Please make sure that the python version used in the training container " - "is the same as the local python version in case of unexpected behaviors.", + "Please make sure that the sagemaker version used in the training container " + "is the same as the local sagemaker version in case of unexpected behaviors.", job_sagemaker_pysdk_version, client_sagemaker_pysdk_version, ) diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py index b7d9e10047..ef35c965e9 100644 --- a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py @@ -269,7 +269,7 @@ def test_main_failure_remote_job_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) - validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) + validate_sagemaker.assert_not_called() run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err)) @@ -317,7 +317,7 @@ def test_main_failure_pipeline_step_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) - validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) + validate_sagemaker.assert_not_called() run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err)) From f9fb1b912b23a18008d98fdda54cb4f3111ce871 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Date: Tue, 12 Mar 2024 08:17:36 -0700 Subject: [PATCH 28/42] change: enable github actions for PRs (#4489) * change: enable github actions for PRs * Update codebuild-ci.yml * trigger on pull_request_target * add source-version-override * fix permission --- .github/workflows/codebuild-ci.yml | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 .github/workflows/codebuild-ci.yml diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml new file mode 100644 index 0000000000..e72680be2a --- /dev/null +++ b/.github/workflows/codebuild-ci.yml @@ -0,0 +1,48 @@ +name: PR Checks +on: + pull_request_target: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }} + cancel-in-progress: true + +permissions: + id-token: write # This is required for requesting the JWT + +jobs: + codestyle-doc-tests: + runs-on: ubuntu-latest + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Codestyle & Doc Tests + uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-codestyle-doc-tests + source-version-override: 'pr/${{ github.event.pull_request.number }}' + unit-tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["py38", "py39", "py310"] + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} + aws-region: us-west-2 + role-duration-seconds: 10800 + - name: Run Unit Tests 
+ uses: aws-actions/aws-codebuild-run-build@v1 + with: + project-name: sagemaker-python-sdk-ci-unit-tests + source-version-override: 'pr/${{ github.event.pull_request.number }}' + env-vars-for-codebuild: | + PY_VERSION + env: + PY_VERSION: ${{ matrix.python-version }} From e95ed65e803730abeedf1577766264fc0fdb3193 Mon Sep 17 00:00:00 2001 From: mrudulmn <161017394+mrudulmn@users.noreply.github.com> Date: Wed, 13 Mar 2024 01:27:05 +0530 Subject: [PATCH 29/42] feature: Add ModelDataSource and SourceUri support for model package and while registering (#4492) Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> --- src/sagemaker/chainer/model.py | 4 + src/sagemaker/estimator.py | 3 + src/sagemaker/huggingface/model.py | 4 + src/sagemaker/jumpstart/factory/model.py | 2 + src/sagemaker/jumpstart/model.py | 4 + src/sagemaker/jumpstart/types.py | 3 + src/sagemaker/model.py | 98 +++++++- src/sagemaker/mxnet/model.py | 4 + src/sagemaker/pipeline.py | 4 + src/sagemaker/pytorch/model.py | 4 + src/sagemaker/session.py | 123 +++++++++- src/sagemaker/sklearn/model.py | 4 + src/sagemaker/tensorflow/model.py | 4 + src/sagemaker/utils.py | 18 ++ src/sagemaker/workflow/_utils.py | 4 + src/sagemaker/workflow/step_collections.py | 3 + src/sagemaker/xgboost/model.py | 4 + tests/integ/test_model_package.py | 206 ++++++++++++++++ tests/unit/sagemaker/model/test_model.py | 88 ++++++- .../sagemaker/model/test_model_package.py | 75 ++++-- tests/unit/test_session.py | 227 ++++++++++++++++++ tests/unit/test_utils.py | 13 + 22 files changed, 864 insertions(+), 35 deletions(-) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index bafcfde3a8..9fce051454 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -174,6 +174,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -223,6 +224,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -262,6 +265,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 7ef367e485..501c826f82 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1718,6 +1718,7 @@ def register( nearest_model_name=None, data_input_configuration=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1765,6 +1766,7 @@ def register( data_input_configuration (str): Input object for the model (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: Passed to invocation of ``create_model()``. 
Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1809,6 +1811,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index efe6a85288..f71dca0ac8 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -360,6 +360,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -410,6 +411,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -457,6 +460,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 9448e45cc2..99ae05fc44 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -623,6 +623,7 @@ def get_register_kwargs( nearest_model_name: Optional[str] = None, data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, + source_uri: Optional[str] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -655,6 +656,7 @@ def get_register_kwargs( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index c0da00ac56..f96181479a 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -643,6 +643,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -688,6 +689,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
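As context for the ``source_uri`` parameter threaded through the ``register()`` methods above, here is a minimal usage sketch (not part of the patch itself) of registering a versioned model package that points back at an existing package; the bucket, IAM role, group name, and ARN are placeholders:

    from sagemaker.xgboost import XGBoostModel

    # Placeholder artifact location and IAM role -- substitute real values.
    model = XGBoostModel(
        model_data="s3://my-bucket/xgb_model.tar.gz",
        framework_version="1.3-1",
        role="arn:aws:iam::123456789012:role/SageMakerRole",
    )

    # Registering into a model package group with source_uri set to another
    # model package ARN lets the registry auto-populate details from it,
    # per the create_model_package submit logic added in this change.
    model_package = model.register(
        content_types=["text/csv"],
        response_types=["text/csv"],
        inference_instances=["ml.m5.large"],
        transform_instances=["ml.m5.large"],
        model_package_group_name="my-model-group",
        source_uri="arn:aws:sagemaker:us-west-2:123456789012:model-package/existing-mp",
    )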
@@ -722,6 +725,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8c74e007ae..8bdb3eb57b 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1908,6 +1908,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "nearest_model_name", "data_input_configuration", "skip_model_validation", + "source_uri", ] SERIALIZATION_EXCLUSION_SET = { @@ -1950,6 +1951,7 @@ def __init__( nearest_model_name: Optional[str] = None, data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, + source_uri: Optional[str] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object.""" @@ -1982,3 +1984,4 @@ def __init__( self.nearest_model_name = nearest_model_name self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation + self.source_uri = source_uri diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5a2b27c54d..af08d1203f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -77,7 +77,10 @@ ) from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType -from sagemaker.session import get_add_model_package_inference_args +from sagemaker.session import ( + get_add_model_package_inference_args, + get_update_model_package_inference_args, +) # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") @@ -423,6 +426,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -472,17 +476,14 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments in case the Model instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ - if isinstance(self.model_data, dict): - raise ValueError( - "SageMaker Model Package currently cannot be created with ModelDataSource." - ) - if content_types is not None: self.content_types = content_types @@ -513,6 +514,12 @@ def register( "Image": self.image_uri, } + if isinstance(self.model_data, dict): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " + "created with ModelDataSource." + ) + if self.model_data is not None: container_def["ModelDataUrl"] = self.model_data @@ -536,6 +543,7 @@ def register( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args @@ -2040,8 +2048,9 @@ def __init__( endpoints use this role to access training data and model artifacts. 
After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. - model_data (str): The S3 location of a SageMaker model data - ``.tar.gz`` file. Must be provided if algorithm_arn is provided. + model_data (str or dict[str, Any]): The S3 location of a SageMaker model data + ``.tar.gz`` file or a dictionary representing a ``ModelDataSource`` + object. Must be provided if algorithm_arn is provided. algorithm_arn (str): algorithm arn used to train the model, can be just the name if your account owns the algorithm. Must also provide ``model_data``. @@ -2050,11 +2059,6 @@ def __init__( ``model_data`` is not required. **kwargs: Additional kwargs passed to the Model constructor. """ - if isinstance(model_data, dict): - raise ValueError( - "Creating ModelPackage with ModelDataSource is currently not supported" - ) - super(ModelPackage, self).__init__( role=role, model_data=model_data, image_uri=None, **kwargs ) @@ -2222,6 +2226,74 @@ def update_customer_metadata(self, customer_metadata_properties: Dict[str, str]) sagemaker_session = self.sagemaker_session or sagemaker.Session() sagemaker_session.sagemaker_client.update_model_package(**update_metadata_args) + def update_inference_specification( + self, + containers: Dict = None, + image_uris: List[str] = None, + content_types: List[str] = None, + response_types: List[str] = None, + inference_instances: List[str] = None, + transform_instances: List[str] = None, + ): + """Inference specification to be set for the model package + + Args: + containers (dict): The Amazon ECR registry path of the Docker image + that contains the inference code. + image_uris (List[str]): The ECR path where inference code is stored. + content_types (list[str]): The supported MIME types + for the input data. + response_types (list[str]): The supported MIME types + for the output data. + inference_instances (list[str]): A list of the instance + types that are used to generate inferences in real-time (default: None). + transform_instances (list[str]): A list of the instance + types on which a transformation job can be run or on which an endpoint can be + deployed (default: None). + + """ + sagemaker_session = self.sagemaker_session or sagemaker.Session() + if (containers is not None) ^ (image_uris is None): + raise ValueError("Should have either containers or image_uris for inference.") + container_def = [] + if image_uris: + for uri in image_uris: + container_def.append( + { + "Image": uri, + } + ) + else: + container_def = containers + + model_package_update_args = get_update_model_package_inference_args( + model_package_arn=self.model_package_arn, + containers=container_def, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + ) + + sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args) + + def update_source_uri( + self, + source_uri: str, + ): + """Source uri to be set for the model package + + Args: + source_uri (str): The URI of the source for the model package. 
+ + """ + update_source_uri_args = { + "ModelPackageArn": self.model_package_arn, + "SourceUri": source_uri, + } + sagemaker_session = self.sagemaker_session or sagemaker.Session() + sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args) + def remove_customer_metadata_properties( self, customer_metadata_properties_to_remove: List[str] ): diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 8cd0ac6b65..714b0db945 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -176,6 +176,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -225,6 +226,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -264,6 +267,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index a4b7feac69..3bfdb1a594 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -360,6 +360,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -409,6 +410,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step @@ -456,6 +459,7 @@ def register( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index fb731cabf4..f490e49375 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -178,6 +178,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -227,6 +228,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). 
+ source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -266,6 +269,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5f3fa5e5a0..d0b4448520 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -140,6 +140,7 @@ ) from sagemaker import exceptions from sagemaker.session_settings import SessionSettings +from sagemaker.utils import can_model_package_source_uri_autopopulate # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") @@ -3969,14 +3970,19 @@ def create_model_package_from_algorithm(self, name, description, algorithm_arn, name (str): ModelPackage name description (str): Model Package description algorithm_arn (str): arn or name of the algorithm used for training. - model_data (str): s3 URI to the model artifacts produced by training + model_data (str or dict[str, Any]): s3 URI or a dictionary representing a + ``ModelDataSource`` to the model artifacts produced by training """ + sourceAlgorithm = {"AlgorithmName": algorithm_arn} + if isinstance(model_data, dict): + sourceAlgorithm["ModelDataSource"] = model_data + else: + sourceAlgorithm["ModelDataUrl"] = model_data + request = { "ModelPackageName": name, "ModelPackageDescription": description, - "SourceAlgorithmSpecification": { - "SourceAlgorithms": [{"AlgorithmName": algorithm_arn, "ModelDataUrl": model_data}] - }, + "SourceAlgorithmSpecification": {"SourceAlgorithms": [sourceAlgorithm]}, } try: logger.info("Creating model package with name: %s", name) @@ -4011,6 +4017,7 @@ def create_model_package_from_containers( sample_payload_url=None, task=None, skip_model_validation="None", + source_uri=None, ): """Get request dictionary for CreateModelPackage API. @@ -4047,6 +4054,7 @@ def create_model_package_from_containers( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -4103,6 +4111,7 @@ def create_model_package_from_containers( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def submit(request): @@ -4114,6 +4123,26 @@ def submit(request): ModelPackageGroupName=request["ModelPackageGroupName"] ) ) + if "SourceUri" in request and request["SourceUri"] is not None: + # Remove inference spec from request if the + # given source uri can lead to auto-population of it + if can_model_package_source_uri_autopopulate(request["SourceUri"]): + if "InferenceSpecification" in request: + del request["InferenceSpecification"] + return self.sagemaker_client.create_model_package(**request) + # If source uri can't autopopulate, + # first create model package with just the inference spec + # and then update model package with the source uri. + # Done this way because passing source uri and inference spec together + # in create/update model package is not allowed in the base sdk. 
+ request_source_uri = request["SourceUri"] + del request["SourceUri"] + model_package = self.sagemaker_client.create_model_package(**request) + update_source_uri_args = { + "ModelPackageArn": model_package.get("ModelPackageArn"), + "SourceUri": request_source_uri, + } + return self.sagemaker_client.update_model_package(**update_source_uri_args) return self.sagemaker_client.create_model_package(**request) return self._intercept_create_request( @@ -6932,6 +6961,7 @@ def get_model_package_args( sample_payload_url=None, task=None, skip_model_validation=None, + source_uri=None, ): """Get arguments for create_model_package method. @@ -6970,6 +7000,7 @@ def get_model_package_args( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). Returns: dict: A dictionary of method argument names and values. @@ -7024,6 +7055,8 @@ def get_model_package_args( model_package_args["task"] = task if skip_model_validation is not None: model_package_args["skip_model_validation"] = skip_model_validation + if source_uri is not None: + model_package_args["source_uri"] = source_uri return model_package_args @@ -7048,6 +7081,7 @@ def get_create_model_package_request( sample_payload_url=None, task=None, skip_model_validation="None", + source_uri=None, ): """Get request dictionary for CreateModelPackage API. @@ -7084,12 +7118,32 @@ def get_create_model_package_request( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). """ if all([model_package_name, model_package_group_name]): raise ValueError( "model_package_name and model_package_group_name cannot be present at the " "same time." ) + if all([model_package_name, source_uri]): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " "created with source_uri." + ) + if (containers is not None) and all( + [ + model_package_name, + any( + [ + (("ModelDataSource" in c) and (c["ModelDataSource"] is not None)) + for c in containers + ] + ), + ] + ): + raise ValueError( + "Un-versioned SageMaker Model Package currently cannot be " + "created with ModelDataSource." + ) request_dict = {} if model_package_name is not None: request_dict["ModelPackageName"] = model_package_name @@ -7115,6 +7169,8 @@ def get_create_model_package_request( request_dict["SamplePayloadUrl"] = sample_payload_url if task is not None: request_dict["Task"] = task + if source_uri is not None: + request_dict["SourceUri"] = source_uri if containers is not None: inference_specification = { "Containers": containers, @@ -7163,6 +7219,65 @@ def get_create_model_package_request( return request_dict +def get_update_model_package_inference_args( + model_package_arn, + containers=None, + content_types=None, + response_types=None, + inference_instances=None, + transform_instances=None, +): + """Get request dictionary for UpdateModelPackage API for inference specification. + + Args: + model_package_arn (str): Arn for the model package. + containers (dict): The Amazon ECR registry path of the Docker image + that contains the inference code. + content_types (list[str]): The supported MIME types + for the input data. 
+ response_types (list[str]): The supported MIME types + for the output data. + inference_instances (list[str]): A list of the instance + types that are used to generate inferences in real-time (default: None). + transform_instances (list[str]): A list of the instance + types on which a transformation job can be run or on which an endpoint can be + deployed (default: None). + """ + + request_dict = {} + if containers is not None: + inference_specification = { + "Containers": containers, + } + if content_types is not None: + inference_specification.update( + { + "SupportedContentTypes": content_types, + } + ) + if response_types is not None: + inference_specification.update( + { + "SupportedResponseMIMETypes": response_types, + } + ) + if inference_instances is not None: + inference_specification.update( + { + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + } + ) + if transform_instances is not None: + inference_specification.update( + { + "SupportedTransformInstanceTypes": transform_instances, + } + ) + request_dict["InferenceSpecification"] = inference_specification + request_dict.update({"ModelPackageArn": model_package_arn}) + return request_dict + + def get_add_model_package_inference_args( model_package_arn, name, diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 195a6a3a57..27833c1d9c 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -171,6 +171,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -220,6 +221,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -259,6 +262,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 1b35afbe7c..77f162207c 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -233,6 +233,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -282,6 +283,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
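A short usage sketch for the two new ``ModelPackage`` helpers introduced in this change, ``update_inference_specification`` and ``update_source_uri``; the ARNs and image URI below are placeholders, not values from the patch:

    from sagemaker.model import ModelPackage

    # Placeholder role and versioned model package ARN.
    mp = ModelPackage(
        role="arn:aws:iam::123456789012:role/SageMakerRole",
        model_package_arn="arn:aws:sagemaker:us-west-2:123456789012:model-package/my-group/1",
    )

    # Replace the package's inference specification with a single container image
    # (exactly one of image_uris or containers may be supplied).
    mp.update_inference_specification(
        image_uris=["123456789012.dkr.ecr.us-west-2.amazonaws.com/my-inference-image:latest"],
    )

    # Record where the package came from, e.g. the SageMaker Model it was cloned from.
    mp.update_source_uri(
        source_uri="arn:aws:sagemaker:us-west-2:123456789012:model/my-model",
    )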
@@ -321,6 +324,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def deploy( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 115b8b258d..7896aac150 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -48,6 +48,10 @@ from sagemaker.workflow.entities import PipelineVariable ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" +MODEL_PACKAGE_ARN_PATTERN = ( + r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)" +) +MODEL_ARN_PATTERN = r"arn:aws([a-z\-]*):sagemaker:([a-z0-9\-]*):([0-9]{12}):model/(.*)" MAX_BUCKET_PATHS_COUNT = 5 S3_PREFIX = "s3://" HTTP_PREFIX = "http://" @@ -1581,3 +1585,17 @@ def custom_extractall_tarfile(tar, extract_path): tar.extractall(path=extract_path, filter="data") else: tar.extractall(path=extract_path, members=_get_safe_members(tar)) + + +def can_model_package_source_uri_autopopulate(source_uri: str): + """Checks if the source_uri can lead to auto-population of information in the Model registry. + + Args: + source_uri (str): The source uri. + + Returns: + bool: True if the source_uri can lead to auto-population, False otherwise. + """ + return bool( + re.match(MODEL_PACKAGE_ARN_PATTERN, source_uri) or re.match(MODEL_ARN_PATTERN, source_uri) + ) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 1fafa646bf..841cd68083 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -328,6 +328,7 @@ def __init__( sample_payload_url=None, task=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Constructor of a register model step. @@ -379,6 +380,7 @@ def __init__( "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -415,6 +417,7 @@ def __init__( self.kwargs = kwargs self.container_def_list = container_def_list self.skip_model_validation = skip_model_validation + self.source_uri = source_uri self._properties = Properties( step_name=name, step=self, shape_name="DescribeModelPackageOutput" @@ -489,6 +492,7 @@ def arguments(self) -> RequestType: sample_payload_url=self.sample_payload_url, task=self.task, skip_model_validation=self.skip_model_validation, + source_uri=self.source_uri, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index d48bf7c307..0eedf4aa96 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -96,6 +96,7 @@ def __init__( nearest_model_name=None, data_input_configuration=None, skip_model_validation=None, + source_uri=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -153,6 +154,7 @@ def __init__( data_input_configuration (str): Input object for the model (default: None). skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str): The URI of the source for the model package (default: None). **kwargs: additional arguments to `create_model`. 
""" @@ -291,6 +293,7 @@ def __init__( sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + source_uri=source_uri, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 74776f8f72..8101f32721 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -159,6 +159,7 @@ def register( nearest_model_name: Optional[Union[str, PipelineVariable]] = None, data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + source_uri: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -208,6 +209,8 @@ def register( (default: None). skip_model_validation (str or PipelineVariable): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). + source_uri (str or PipelineVariable): The URI of the source for the model package + (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -247,6 +250,7 @@ def register( nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, + source_uri=source_uri, ) def prepare_container_def( diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py index 1554825fc2..914c5db7ed 100644 --- a/tests/integ/test_model_package.py +++ b/tests/integ/test_model_package.py @@ -18,6 +18,8 @@ from tests.integ import DATA_DIR from sagemaker.xgboost import XGBoostModel from sagemaker import image_uris +from sagemaker.session import get_execution_role +from sagemaker.model import ModelPackage _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") @@ -104,3 +106,207 @@ def test_inference_specification_addition(sagemaker_session): sagemaker_session.sagemaker_client.delete_model_package_group( ModelPackageGroupName=model_group_name ) + + +def test_update_inference_specification(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + source_uri = "dummy source uri" + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + model_package = sagemaker_session.sagemaker_client.create_model_package( + ModelPackageGroupName=model_group_name, SourceUri=source_uri + ) + + mp = ModelPackage( + role=get_execution_role(sagemaker_session), + model_package_arn=model_package["ModelPackageArn"], + sagemaker_session=sagemaker_session, + ) + + xgb_image = image_uris.retrieve( + "xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference" + ) + + mp.update_inference_specification(image_uris=[xgb_image]) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert len(desc_model_package["InferenceSpecification"]["Containers"]) == 1 + assert desc_model_package["InferenceSpecification"]["Containers"][0]["Image"] == xgb_image + + +def test_update_source_uri(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + source_uri = "dummy source uri" + + sagemaker_session.sagemaker_client.create_model_package_group( + 
ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + model_package.update_source_uri(source_uri=source_uri) + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + assert desc_model_package["SourceUri"] == source_uri + + +def test_clone_model_package_using_source_uri(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + source_uri="dummy-source-uri", + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + model2 = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + cloned_model_package = model2.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + source_uri=model_package.model_package_arn, + ) + + desc_cloned_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=cloned_model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=cloned_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert len(desc_cloned_model_package["InferenceSpecification"]["Containers"]) == len( + desc_model_package["InferenceSpecification"]["Containers"] + ) + assert len( + desc_cloned_model_package["InferenceSpecification"]["SupportedTransformInstanceTypes"] + ) == len(desc_model_package["InferenceSpecification"]["SupportedTransformInstanceTypes"]) + assert len( + desc_cloned_model_package["InferenceSpecification"][ + "SupportedRealtimeInferenceInstanceTypes" + ] + ) == len( + desc_model_package["InferenceSpecification"]["SupportedRealtimeInferenceInstanceTypes"] + ) + assert len(desc_cloned_model_package["InferenceSpecification"]["SupportedContentTypes"]) == len( + desc_model_package["InferenceSpecification"]["SupportedContentTypes"] + ) + assert len( + 
desc_cloned_model_package["InferenceSpecification"]["SupportedResponseMIMETypes"] + ) == len(desc_model_package["InferenceSpecification"]["SupportedResponseMIMETypes"]) + assert desc_cloned_model_package["SourceUri"] == model_package.model_package_arn + + +def test_register_model_using_source_uri(sagemaker_session): + model_name = unique_name_from_base("test-model") + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + + model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + sagemaker_session=sagemaker_session, + role=get_execution_role(sagemaker_session), + ) + + model.name = model_name + model.create() + desc_model = sagemaker_session.sagemaker_client.describe_model(ModelName=model_name) + + model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + sagemaker_session=sagemaker_session, + role=get_execution_role(sagemaker_session), + ) + registered_model_package = model.register( + inference_instances=["ml.m5.xlarge"], + model_package_group_name=model_group_name, + source_uri=desc_model["ModelArn"], + ) + + desc_registered_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=registered_model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model(ModelName=model_name) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=registered_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + assert desc_registered_model_package["SourceUri"] == desc_model["ModelArn"] + assert "InferenceSpecification" in desc_registered_model_package + assert desc_registered_model_package["InferenceSpecification"] is not None diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index de86fcf99a..c0b18a3eb3 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -1118,7 +1118,46 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses get_model_package_args""" -def test_register_calls_model_data_source_not_supported(sagemaker_session): +@patch("sagemaker.get_model_package_args") +def test_register_passes_source_uri_to_model_package_args( + get_model_package_args, sagemaker_session +): + source_dir = "s3://blah/blah/blah" + source_uri = "dummy_source_uri" + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + t.register( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_RESPONSE_MIME_TYPES, + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, + marketplace_cert=True, + description=MODEL_DESCRIPTION, + model_package_name=MODEL_NAME, + validation_specification=VALIDATION_SPECIFICATION, + source_uri=source_uri, + ) + + # check that the kwarg source_uri was passed to the internal method 'get_model_package_args' + assert ( + "source_uri" in get_model_package_args.call_args_list[0][1] + ), "source_uri kwarg was not passed to get_model_package_args" + + # check that the kwarg source_uri is identical to the one passed into the method 
'register' + assert ( + source_uri == get_model_package_args.call_args_list[0][1]["source_uri"] + ), """source_uri from model.register method is not identical to source_uri from + get_model_package_args""" + + +def test_register_with_model_data_source_not_supported_for_unversioned_model(sagemaker_session): source_dir = "s3://blah/blah/blah" t = Model( entry_point=ENTRY_POINT_INFERENCE, @@ -1137,7 +1176,7 @@ def test_register_calls_model_data_source_not_supported(sagemaker_session): with pytest.raises( ValueError, - match="SageMaker Model Package currently cannot be created with ModelDataSource.", + match="Un-versioned SageMaker Model Package currently cannot be created with ModelDataSource.", ): t.register( SUPPORTED_CONTENT_TYPES, @@ -1151,6 +1190,51 @@ def test_register_calls_model_data_source_not_supported(sagemaker_session): ) +@patch("sagemaker.get_model_package_args") +def test_register_with_model_data_source_supported_for_versioned_model( + get_model_package_args, sagemaker_session +): + source_dir = "s3://blah/blah/blah" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data=model_data_source, + ) + + t.register( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_RESPONSE_MIME_TYPES, + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, + marketplace_cert=True, + description=MODEL_DESCRIPTION, + model_package_group_name="dummy_group", + validation_specification=VALIDATION_SPECIFICATION, + ) + + # check that the kwarg container_def_list was set for the internal method 'get_model_package_args' + assert ( + "container_def_list" in get_model_package_args.call_args_list[0][1] + ), "container_def_list kwarg was not set to get_model_package_args" + + # check that the kwarg container in container_def_list contains the model data source + assert ( + model_data_source + == get_model_package_args.call_args_list[0][1]["container_def_list"][0]["ModelDataSource"] + ), """model_data_source from model.register method is not identical to ModelDataSource from + get_model_package_args""" + + @patch("sagemaker.utils.repack_model") def test_model_local_download_dir(repack_model, sagemaker_session): diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index def7ddf5e3..9bfc830a75 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -223,22 +223,21 @@ def test_create_sagemaker_model_include_tags(sagemaker_session): ) -def test_model_package_model_data_source_not_supported(sagemaker_session): - with pytest.raises( - ValueError, match="Creating ModelPackage with ModelDataSource is currently not supported" - ): - ModelPackage( - role="role", - model_package_arn="my-model-package", - model_data={ - "S3DataSource": { - "S3Uri": "s3://bucket/model/prefix/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } - }, - sagemaker_session=sagemaker_session, - ) +def test_model_package_model_data_source_supported(sagemaker_session): + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + model_package = ModelPackage( + role="role", + model_package_arn="my-model-package", + model_data=model_data_source, + 
sagemaker_session=sagemaker_session, + ) + assert model_package.model_data == model_package.model_data @patch("sagemaker.utils.name_from_base") @@ -399,3 +398,47 @@ def test_add_inference_specification(sagemaker_session): } ], ) + + +def test_update_inference_specification(sagemaker_session): + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + + image_uris = ["image_uri"] + + containers = [{"Image": "image_uri"}] + + try: + model_package.update_inference_specification(image_uris=image_uris, containers=containers) + except ValueError as ve: + assert "Should have either containers or image_uris for inference." in str(ve) + + try: + model_package.update_inference_specification() + except ValueError as ve: + assert "Should have either containers or image_uris for inference." in str(ve) + + model_package.update_inference_specification(image_uris=image_uris) + + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, + InferenceSpecification={ + "Containers": [{"Image": "image_uri"}], + }, + ) + + +def test_update_source_uri(sagemaker_session): + source_uri = "dummy_source_uri" + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + model_package.update_source_uri(source_uri=source_uri) + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, SourceUri=source_uri + ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9ea40cf140..54fa5f2595 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5119,6 +5119,233 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) +def test_create_model_package_from_containers_with_source_uri_and_inference_spec(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "dummy-source-uri" + + created_versioned_mp_arn = ( + "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + ) + sagemaker_session.sagemaker_client.create_model_package = Mock( + return_value={"ModelPackageArn": created_versioned_mp_arn} + ) + + sagemaker_session.create_model_package_from_containers( + model_package_group_name=model_package_group_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + expected_create_mp_args = { + "ModelPackageGroupName": model_package_group_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + "SupportedTransformInstanceTypes": transform_instances, + }, + "CertifyForMarketplace": marketplace_cert, + 
"ModelApprovalStatus": approval_status, + "SkipModelValidation": skip_model_validation, + } + + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + expected_update_mp_args = { + "ModelPackageArn": created_versioned_mp_arn, + "SourceUri": source_uri, + } + sagemaker_session.sagemaker_client.update_model_package.assert_called_once_with( + **expected_update_mp_args + ) + + +def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp(sagemaker_session): + model_package_name = "sagemaker-model-package" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "dummy-source-uri" + + with pytest.raises( + ValueError, + match="Un-versioned SageMaker Model Package currently cannot be created with source_uri.", + ): + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + + +def test_create_model_package_from_containers_with_source_uri_for_versioned_mp(sagemaker_session): + model_package_name = "sagemaker-model-package" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + containers = [{"Image": "dummy-image", "ModelDataSource": model_data_source}] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + + with pytest.raises( + ValueError, + match="Un-versioned SageMaker Model Package currently cannot be created with ModelDataSource.", + ): + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + containers=containers, + content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + ) + + +def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarget"] + marketplace_cert = (False,) + approval_status = ("Approved",) + skip_model_validation = "All" + source_uri = "arn:aws:sagemaker:us-west-2:123456789123:model-package/existing-mp" + + created_versioned_mp_arn = ( + "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + ) + sagemaker_session.sagemaker_client.create_model_package = Mock( + return_value={"ModelPackageArn": created_versioned_mp_arn} + ) + + sagemaker_session.create_model_package_from_containers( + model_package_group_name=model_package_group_name, + containers=containers, + content_types=content_types, + 
response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + skip_model_validation=skip_model_validation, + source_uri=source_uri, + ) + expected_create_mp_args = { + "ModelPackageGroupName": model_package_group_name, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "SkipModelValidation": skip_model_validation, + "SourceUri": source_uri, + } + + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + sagemaker_session.sagemaker_client.update_model_package.assert_not_called() + + +def test_create_model_package_from_algorithm_with_model_data_source(sagemaker_session): + model_package_name = "sagemaker-model-package" + description = "dummy description" + model_data_source = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + algorithm_arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees" + sagemaker_session.create_model_package_from_algorithm( + algorithm_arn=algorithm_arn, + model_data=model_data_source, + name=model_package_name, + description=description, + ) + expected_create_mp_args = { + "ModelPackageName": model_package_name, + "ModelPackageDescription": description, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "AlgorithmName": algorithm_arn, + "ModelDataSource": model_data_source, + } + ] + }, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + + +def test_create_model_package_from_algorithm_with_model_data_url(sagemaker_session): + model_package_name = "sagemaker-model-package" + description = "dummy description" + model_data_url = "s3://bucket/key" + algorithm_arn = "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees" + sagemaker_session.create_model_package_from_algorithm( + algorithm_arn=algorithm_arn, + model_data=model_data_url, + name=model_package_name, + description=description, + ) + expected_create_mp_args = { + "ModelPackageName": model_package_name, + "ModelPackageDescription": description, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "AlgorithmName": algorithm_arn, + "ModelDataUrl": model_data_url, + } + ] + }, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_once_with( + **expected_create_mp_args + ) + + def test_create_model_package_from_containers_all_args(sagemaker_session): model_package_name = "sagemaker-model-package" containers = ["dummy-container"] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 852ef8b153..a83f1b995d 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -46,6 +46,7 @@ _is_bad_path, _is_bad_link, custom_extractall_tarfile, + can_model_package_source_uri_autopopulate, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1796,3 +1797,15 @@ def test_is_bad_link(link_name, base, expected): def test_custom_extractall_tarfile(mock_custom_tarfile, data_filter, expected_extract_path): tar = mock_custom_tarfile(data_filter) custom_extractall_tarfile(tar, "/extract/path") + + +def test_can_model_package_source_uri_autopopulate(): + test_data = [ + ("arn:aws:sagemaker:us-west-2:012345678912:model-package/dummy-mpg/1", True), + 
("arn:aws:sagemaker:us-west-2:012345678912:model-package/dummy-mp", True), + ("arn:aws:sagemaker:us-west-2:012345678912:model/dummy-model", True), + ("https://path/to/model", False), + ("/home/path/to/model", False), + ] + for source_uri, expected in test_data: + assert can_model_package_source_uri_autopopulate(source_uri) == expected From 525e9aedf1f941c5e2c0cb408712139e58985a1b Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 12 Mar 2024 16:08:33 -0400 Subject: [PATCH 30/42] feat: support JumpStart proprietary models (#4467) * feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> --- doc/doc_utils/jumpstart_doc_utils.py | 83 ++++- src/sagemaker/accept_types.py | 3 + src/sagemaker/base_predictor.py | 4 +- src/sagemaker/content_types.py | 3 + src/sagemaker/deserializers.py | 3 + src/sagemaker/instance_types.py | 3 + src/sagemaker/jumpstart/accessors.py | 24 +- .../jumpstart/artifacts/instance_types.py | 3 + src/sagemaker/jumpstart/artifacts/kwargs.py | 5 + .../jumpstart/artifacts/model_packages.py | 3 + src/sagemaker/jumpstart/artifacts/payloads.py | 3 + .../jumpstart/artifacts/predictors.py | 15 + .../jumpstart/artifacts/resource_names.py | 3 + .../artifacts/resource_requirements.py | 3 + src/sagemaker/jumpstart/cache.py | 255 +++++++++++---- src/sagemaker/jumpstart/constants.py | 17 +- src/sagemaker/jumpstart/enums.py | 14 + src/sagemaker/jumpstart/estimator.py | 13 +- src/sagemaker/jumpstart/exceptions.py | 70 +++- src/sagemaker/jumpstart/factory/estimator.py | 4 +- src/sagemaker/jumpstart/factory/model.py | 48 ++- src/sagemaker/jumpstart/filters.py | 24 ++ src/sagemaker/jumpstart/model.py | 74 ++++- src/sagemaker/jumpstart/notebook_utils.py | 56 +++- src/sagemaker/jumpstart/types.py | 54 ++- src/sagemaker/jumpstart/utils.py | 56 +++- src/sagemaker/model.py | 2 +- src/sagemaker/payloads.py | 7 + src/sagemaker/predictor.py | 3 + src/sagemaker/resource_requirements.py | 3 + src/sagemaker/serializers.py | 3 + .../jumpstart/model/test_jumpstart_model.py | 26 +- .../jumpstart/test_accept_types.py | 28 +- .../jumpstart/test_content_types.py | 27 +- .../jumpstart/test_deserializers.py | 27 +- .../jumpstart/test_default.py | 42 ++- .../hyperparameters/jumpstart/test_default.py | 28 +- .../jumpstart/test_validate.py | 41 ++- .../image_uris/jumpstart/test_common.py | 11 +- .../jumpstart/test_instance_types.py | 9 +- tests/unit/sagemaker/jumpstart/constants.py | 81 +++++ .../jumpstart/estimator/test_estimator.py | 158 ++++----- 
.../estimator/test_sagemaker_config.py | 49 +-- .../sagemaker/jumpstart/model/test_model.py | 218 ++++++++----- .../jumpstart/model/test_sagemaker_config.py | 49 +-- .../sagemaker/jumpstart/test_accessors.py | 86 ++++- .../sagemaker/jumpstart/test_artifacts.py | 14 +- tests/unit/sagemaker/jumpstart/test_cache.py | 308 ++++++++++++++++-- .../sagemaker/jumpstart/test_exceptions.py | 34 ++ .../jumpstart/test_notebook_utils.py | 114 +++++-- .../sagemaker/jumpstart/test_predictor.py | 52 ++- tests/unit/sagemaker/jumpstart/test_utils.py | 94 ++++-- tests/unit/sagemaker/jumpstart/utils.py | 59 +++- .../jumpstart/test_default.py | 21 +- .../model_uris/jumpstart/test_common.py | 11 +- .../jumpstart/test_resource_requirements.py | 28 +- .../script_uris/jumpstart/test_common.py | 11 +- .../serializers/jumpstart/test_serializers.py | 27 +- 58 files changed, 2014 insertions(+), 500 deletions(-) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 348de7adeb..458da694d5 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -74,9 +74,12 @@ class Frameworks(str, Enum): JUMPSTART_REGION = "eu-west-2" SDK_MANIFEST_FILE = "models_manifest.json" +PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json" JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format( JUMPSTART_REGION, JUMPSTART_REGION ) +PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com" + TASK_MAP = { Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION, Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING, @@ -152,18 +155,26 @@ class Frameworks(str, Enum): } -def get_jumpstart_sdk_manifest(): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE) +def get_public_s3_json_object(url): with request.urlopen(url) as f: models_manifest = f.read().decode("utf-8") return json.loads(models_manifest) -def get_jumpstart_sdk_spec(key): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key) - with request.urlopen(url) as f: - model_spec = f.read().decode("utf-8") - return json.loads(model_spec) +def get_jumpstart_sdk_manifest(): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}") + + +def get_proprietary_sdk_manifest(): + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}") + + +def get_jumpstart_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}") + + +def get_proprietary_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{s3_key}") def get_model_task(id): @@ -196,6 +207,45 @@ def get_model_source(url): return "Source" +def create_proprietary_model_table(): + proprietary_content_intro = [] + proprietary_content_intro.append("\n") + proprietary_content_intro.append(".. 
list-table:: Available Proprietary Models\n") + proprietary_content_intro.append(" :widths: 50 20 20 20 20\n") + proprietary_content_intro.append(" :header-rows: 1\n") + proprietary_content_intro.append(" :class: datatable\n") + proprietary_content_intro.append("\n") + proprietary_content_intro.append(" * - Model ID\n") + proprietary_content_intro.append(" - Fine Tunable?\n") + proprietary_content_intro.append(" - Supported Version\n") + proprietary_content_intro.append(" - Min SDK Version\n") + proprietary_content_intro.append(" - Source\n") + + sdk_manifest = get_proprietary_sdk_manifest() + sdk_manifest_top_versions_for_models = {} + + for model in sdk_manifest: + if model["model_id"] not in sdk_manifest_top_versions_for_models: + sdk_manifest_top_versions_for_models[model["model_id"]] = model + else: + if str(sdk_manifest_top_versions_for_models[model["model_id"]]["version"]) < str( + model["version"] + ): + sdk_manifest_top_versions_for_models[model["model_id"]] = model + + proprietary_content_entries = [] + for model in sdk_manifest_top_versions_for_models.values(): + model_spec = get_proprietary_sdk_spec(model["spec_key"]) + proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training + proprietary_content_entries.append(" - {}\n".format(model["version"])) + proprietary_content_entries.append(" - {}\n".format(model["min_version"])) + proprietary_content_entries.append( + " - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url")) + ) + return proprietary_content_intro + proprietary_content_entries + ["\n"] + + def create_jumpstart_model_table(): sdk_manifest = get_jumpstart_sdk_manifest() sdk_manifest_top_versions_for_models = {} @@ -249,19 +299,19 @@ def create_jumpstart_model_table(): file_content_intro.append(" - Source\n") dynamic_table_files = [] - file_content_entries = [] + open_weight_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) model_task = get_model_task(model_spec["model_id"]) string_model_task = get_string_model_task(model_spec["model_id"]) model_source = get_model_source(model_spec["url"]) - file_content_entries.append(" * - {}\n".format(model_spec["model_id"])) - file_content_entries.append(" - {}\n".format(model_spec["training_supported"])) - file_content_entries.append(" - {}\n".format(model["version"])) - file_content_entries.append(" - {}\n".format(model["min_version"])) - file_content_entries.append(" - {}\n".format(model_task)) - file_content_entries.append( + open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"])) + open_weight_content_entries.append(" - {}\n".format(model["version"])) + open_weight_content_entries.append(" - {}\n".format(model["min_version"])) + open_weight_content_entries.append(" - {}\n".format(model_task)) + open_weight_content_entries.append( " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) ) @@ -299,7 +349,10 @@ def create_jumpstart_model_table(): f.writelines(file_content_single_entry) f.close() + proprietary_content_entries = create_proprietary_model_table() + f = open("doc_utils/pretrainedmodels.rst", "a") f.writelines(file_content_intro) - f.writelines(file_content_entries) + f.writelines(open_weight_content_entries) + f.writelines(proprietary_content_entries) f.close() diff --git 
a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 14212fd991..0327ef3845 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -80,6 +81,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -122,4 +124,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 882cfafc39..76b83c25cd 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -58,7 +58,9 @@ from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME from sagemaker.lineage.context import EndpointContext -from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.compute_resource_requirements.resource_requirements import ( + ResourceRequirements, +) LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 5e82201c31..3154c1e4fe 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -80,6 +81,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -122,6 +124,7 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 7bb08ce15a..3081daea23 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -35,6 +35,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -100,6 +101,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. 
@@ -143,4 +145,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index d277f8cf3b..fecd769011 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -89,6 +91,7 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, + model_type=model_type, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 456ba6baf6..dfc833ec28 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -18,6 +18,7 @@ from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -198,7 +199,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: @staticmethod def _get_manifest( - region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. @@ -216,13 +219,19 @@ def _get_manifest( additional_kwargs.update({"s3_client": s3_client}) cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, + region, ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - return JumpStartModelsAccessor._cache.get_manifest() # type: ignore + return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore @staticmethod - def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: + def get_model_header( + region: str, + model_id: str, + version: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. 
Args: @@ -235,7 +244,9 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_header( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, + semantic_version_str=version, + model_type=model_type, ) @staticmethod @@ -245,6 +256,7 @@ def get_model_specs( version: str, hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. @@ -272,7 +284,7 @@ def get_model_specs( return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) return JumpStartModelsAccessor._cache.get_specs( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version, model_type=model_type ) @staticmethod diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 5f252f00ad..24a44cf548 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -39,6 +40,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model. @@ -88,6 +90,7 @@ def _retrieve_default_instance_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index c9edeb2e76..f7e3ff3d96 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -36,6 +37,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model`. @@ -75,6 +77,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -94,6 +97,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model.deploy`. 
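# --- Illustrative sketch (editor's addition, not part of the commit above): how the
# accessor's new model_type argument might be used to fetch specs for a proprietary
# model. The model ID and version are hypothetical placeholders; other keyword
# arguments keep their defaults.
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
from sagemaker.jumpstart.enums import JumpStartModelType

specs = JumpStartModelsAccessor.get_model_specs(
    region="us-west-2",
    model_id="example-proprietary-model",  # hypothetical model ID
    version="1.0.0",  # proprietary models require an exact version
    model_type=JumpStartModelType.PROPRIETARY,
)
print(specs.model_id, specs.version)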
@@ -136,6 +140,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 4a0fc147d5..0a5c3b7bed 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.session import Session @@ -36,6 +37,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -78,6 +80,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 21fca2abd2..a89da8d1bd 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -20,6 +20,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( @@ -36,6 +37,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -76,6 +78,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 35fe4e3dcf..8d7cc26c41 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -26,6 +26,7 @@ from sagemaker.jumpstart.enums import ( JumpStartScriptScope, MIMEType, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -77,6 +78,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. 
@@ -112,6 +114,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -125,6 +128,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -159,6 +163,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -172,6 +177,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -206,6 +212,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) seen_classes: Set[Type] = set() @@ -293,6 +300,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -332,6 +340,7 @@ def _retrieve_default_content_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -346,6 +355,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model. @@ -384,6 +394,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -399,6 +410,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported accept types for the model. 
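# --- Illustrative sketch (editor's addition, not part of the commit above): the public
# retrieve_default helpers patched earlier in this diff accept the same model_type
# argument, so default accept/content types can be resolved for proprietary models too.
# The model ID/version are hypothetical; keyword names other than model_type are
# assumed to follow the existing retrieve_default signatures.
from sagemaker import accept_types, content_types
from sagemaker.jumpstart.enums import JumpStartModelType

default_accept = accept_types.retrieve_default(
    region="us-west-2",
    model_id="example-proprietary-model",
    model_version="1.0.0",
    model_type=JumpStartModelType.PROPRIETARY,
)
default_content = content_types.retrieve_default(
    region="us-west-2",
    model_id="example-proprietary-model",
    model_version="1.0.0",
    model_type=JumpStartModelType.PROPRIETARY,
)
print(default_accept, default_content)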
@@ -437,6 +449,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -452,6 +465,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported content types for the model. @@ -490,6 +504,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index ca3044068b..4edcd43e25 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -19,6 +19,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -33,6 +34,7 @@ def _retrieve_resource_name_base( hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. @@ -72,6 +74,7 @@ def _retrieve_resource_name_base( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 8baaaafd2a..f2097ef8d4 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -21,6 +21,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -51,6 +52,7 @@ def _retrieve_default_resources( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: @@ -101,6 +103,7 @@ def _retrieve_default_resources( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 9d804dc53a..63b47632fd 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -27,12 +27,18 @@ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, 
MODEL_ID_LIST_WEB_URL, DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + MODEL_TYPE_TO_MANIFEST_MAP, + MODEL_TYPE_TO_SPECS_MAP, +) +from sagemaker.jumpstart.exceptions import ( + get_wildcard_model_version_msg, + get_wildcard_proprietary_model_version_msg, ) -from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, @@ -53,6 +59,7 @@ HubContentType, ) from sagemaker.jumpstart.curated_hub import utils as hub_utils +from sagemaker.jumpstart.enums import JumpStartModelType class JumpStartModelsCache: @@ -75,6 +82,7 @@ def __init__( JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, @@ -111,14 +119,26 @@ def __init__( expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, ) - self._model_id_semantic_version_manifest_key_cache = LRUCache[ + self._open_weight_model_id_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId ]( max_cache_items=max_semantic_version_cache_items, expiration_horizon=semantic_version_cache_expiration_horizon, - retrieval_function=self._get_manifest_key_from_model_id_semantic_version, + retrieval_function=self._get_open_weight_manifest_key_from_model_id, + ) + self._proprietary_model_id_manifest_key_cache = LRUCache[ + JumpStartVersionedModelId, JumpStartVersionedModelId + ]( + max_cache_items=max_semantic_version_cache_items, + expiration_horizon=semantic_version_cache_expiration_horizon, + retrieval_function=self._get_proprietary_manifest_key_from_model_id, ) self._manifest_file_s3_key = manifest_file_s3_key + self._proprietary_manifest_s3_key = proprietary_manifest_s3_key + self._manifest_file_s3_map = { + JumpStartModelType.OPEN_WEIGHTS: self._manifest_file_s3_key, + JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key, + } self.s3_bucket_name = ( utils.get_jumpstart_content_bucket(self._region) if s3_bucket_name is None @@ -141,15 +161,40 @@ def get_region(self) -> str: """Return region for cache.""" return self._region - def set_manifest_file_s3_key(self, key: str) -> None: - """Set manifest file s3 key. Clears cache after new key is set.""" - if key != self._manifest_file_s3_key: - self._manifest_file_s3_key = key + def set_manifest_file_s3_key( + self, + key: str, + file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + ) -> None: + """Set manifest file s3 key, clear cache after new key is set. 
+ + Raises: + ValueError: if the file type is not recognized + """ + file_mapping = { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: self._manifest_file_s3_key, + JumpStartS3FileType.PROPRIETARY_MANIFEST: self._proprietary_manifest_s3_key, + } + property_name = file_mapping.get(file_type) + if not property_name: + raise ValueError( + self._file_type_error_msg(file_type, manifest_only=True) + ) + if key != property_name: + setattr(self, property_name, key) self.clear() - def get_manifest_file_s3_key(self) -> str: + def get_manifest_file_s3_key( + self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST + ) -> str: """Return manifest file s3 key for cache.""" - return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return self._proprietary_manifest_s3_key + raise ValueError( + self._file_type_error_msg(file_type, manifest_only=True) + ) def set_s3_bucket_name(self, s3_bucket_name: str) -> None: """Set s3 bucket used for cache.""" @@ -161,10 +206,24 @@ def get_bucket(self) -> str: """Return bucket used for cache.""" return self.s3_bucket_name - def _get_manifest_key_from_model_id_semantic_version( + def _file_type_error_msg(self, file_type: str, manifest_only: bool = False) -> str: + """Return error message for bad model type.""" + if manifest_only: + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST} " + f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}." + ) + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in '{' '.join([e.name for e in JumpStartS3FileType])}'." + ) + + def _model_id_retrieval_function( self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + model_type: JumpStartModelType, ) -> JumpStartVersionedModelId: """Return model ID and version in manifest that matches semantic version/id. @@ -176,6 +235,8 @@ def _get_manifest_key_from_model_id_semantic_version( key (JumpStartVersionedModelId): Key for which to fetch versioned model ID. value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached model ID/version. + model_type (JumpStartModelType): JumpStart model type to indicate whether it is + open weights model or proprietary (Marketplace) model. 
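# --- Illustrative sketch (editor's addition, not part of the commit above): with the
# cache now tracking an open-weights and a proprietary manifest, the file type selects
# which S3 key is read, and get_manifest() (patched further down in this diff) takes the
# model type directly. The region keyword is assumed from the existing cache constructor.
from sagemaker.jumpstart.cache import JumpStartModelsCache
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.types import JumpStartS3FileType

cache = JumpStartModelsCache(region="us-west-2")

# S3 key backing the proprietary manifest (the default argument returns the
# open-weights manifest key instead).
prop_key = cache.get_manifest_file_s3_key(JumpStartS3FileType.PROPRIETARY_MANIFEST)

# Headers for all proprietary (Marketplace) models known to the cache.
proprietary_headers = cache.get_manifest(JumpStartModelType.PROPRIETARY)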
Raises: KeyError: If the semantic version is not found in the manifest, or is found but @@ -183,21 +244,20 @@ def _get_manifest_key_from_model_id_semantic_version( """ model_id, version = key.model_id, key.version - + sm_version = utils.get_sagemaker_version() manifest = self._content_cache.get( - JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content - sm_version = utils.get_sagemaker_version() - versions_compatible_with_sagemaker = [ - Version(header.version) + header.version for header in manifest.values() # type: ignore if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] sm_compatible_model_version = self._select_version( - version, versions_compatible_with_sagemaker + model_id, version, versions_compatible_with_sagemaker, model_type ) if sm_compatible_model_version is not None: @@ -208,7 +268,7 @@ def _get_manifest_key_from_model_id_semantic_version( if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( - version, versions_incompatible_with_sagemaker + model_id, version, versions_incompatible_with_sagemaker, model_type ) if sm_incompatible_model_version is not None: @@ -238,15 +298,27 @@ def _get_manifest_key_from_model_id_semantic_version( f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " ) - other_model_id_version = self._select_version( - "*", versions_incompatible_with_sagemaker - ) # all versions here are incompatible with sagemaker + other_model_id_version = None + if model_type == JumpStartModelType.OPEN_WEIGHTS: + other_model_id_version = self._select_version( + model_id, "*", versions_incompatible_with_sagemaker, model_type + ) # all versions here are incompatible with sagemaker + elif model_type == JumpStartModelType.PROPRIETARY: + all_possible_model_id_version = [ + header.version for header in manifest.values() # type: ignore + if header.model_id == model_id + ] + other_model_id_version = ( + None + if not all_possible_model_id_version + else all_possible_model_id_version[0] + ) + if other_model_id_version is not None: error_msg += ( f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'." ) - else: possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0] @@ -254,6 +326,32 @@ def _get_manifest_key_from_model_id_semantic_version( raise KeyError(error_msg) + def _get_open_weight_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """For open weights models, retrieve model manifest key for open weight model. + + Filters models list by supported versions. + """ + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.OPEN_WEIGHTS + ) + + def _get_proprietary_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """For proprietary models, retrieve model manifest key for proprietary model. + + Filters models list by supported versions. 
+ """ + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.PROPRIETARY + ) + def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]: """Returns json file from s3, along with its etag.""" response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key) @@ -298,11 +396,11 @@ def _get_json_file_from_local_override( filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" - if filetype == JumpStartS3FileType.MANIFEST: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: metadata_local_root = ( os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] ) - elif filetype == JumpStartS3FileType.SPECS: + elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") @@ -330,8 +428,10 @@ def _retrieval_function( """ data_type, id_info = key.data_type, key.id_info - - if data_type == JumpStartS3FileType.MANIFEST: + if data_type in { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartS3FileType.PROPRIETARY_MANIFEST, + }: if value is not None and not self._is_local_metadata_mode(): etag = self._get_json_md5_hash(id_info) if etag == value.md5_hash: @@ -341,13 +441,16 @@ def _retrieval_function( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if data_type == JumpStartS3FileType.SPECS: + + if data_type in { + JumpStartS3FileType.OPEN_WEIGHT_SPECS, + JumpStartS3FileType.PROPRIETARY_SPECS, + }: formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedContentValue( - formatted_content=model_specs - ) + return JumpStartCachedContentValue(formatted_content=model_specs) + if data_type == HubContentType.MODEL: hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( id_info @@ -371,6 +474,7 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) + if data_type == HubType.HUB: hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) @@ -378,21 +482,30 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=DescribeHubResponse(hub_description) ) + + raise ValueError( - f"Bad value for key '{key}': must be in ", - f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}" + self._file_type_error_msg(data_type) ) - def get_manifest(self) -> List[JumpStartModelHeader]: + def get_manifest( + self, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._content_cache.get( - JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest - def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: + def get_header( + self, + model_id: str, + semantic_version_str: str, + model_type: JumpStartModelType = 
JumpStartModelType.OPEN_WEIGHTS, + ) -> JumpStartModelHeader: """Return header for a given JumpStart model ID and semantic version. Args: @@ -401,29 +514,43 @@ def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModel header. """ - return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) + return self._get_header_impl( + model_id, semantic_version_str=semantic_version_str, model_type=model_type + ) def _select_version( self, - semantic_version_str: str, - available_versions: List[Version], + model_id: str, + version_str: str, + available_versions: List[str], + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Perform semantic version search on available versions. Args: - semantic_version_str (str): the semantic version for which to filter + version_str (str): the semantic version for which to filter available versions. available_versions (List[Version]): list of available versions. """ - if semantic_version_str == "*": + + if version_str == "*": if len(available_versions) == 0: return None return str(max(available_versions)) + if model_type == JumpStartModelType.PROPRIETARY: + if "*" in version_str: + raise KeyError( + get_wildcard_proprietary_model_version_msg( + model_id, version_str, available_versions + ) + ) + return version_str if version_str in available_versions else None + try: - spec = SpecifierSet(f"=={semantic_version_str}") + spec = SpecifierSet(f"=={version_str}") except InvalidSpecifier: - raise KeyError(f"Bad semantic version: {semantic_version_str}") + raise KeyError(f"Bad semantic version: {version_str}") available_versions_filtered = list(spec.filter(available_versions)) return ( str(max(available_versions_filtered)) if available_versions_filtered != [] else None @@ -434,6 +561,7 @@ def _get_header_impl( model_id: str, semantic_version_str: str, attempt: int = 0, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -445,14 +573,20 @@ def _get_header_impl( header. attempt (int): attempt number at retrieving a header. 
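# --- Illustrative sketch (editor's addition, not part of the commit above): resolving a
# model header for each model type. Open-weights models still accept semantic-version
# wildcards, while proprietary models reject "*" and must pin an exact version. Model
# IDs and versions below are hypothetical placeholders.
from sagemaker.jumpstart.cache import JumpStartModelsCache
from sagemaker.jumpstart.enums import JumpStartModelType

cache = JumpStartModelsCache(region="us-west-2")

open_weights_header = cache.get_header("example-open-weights-model", "*")
proprietary_header = cache.get_header(
    "example-proprietary-model", "1.0.0", JumpStartModelType.PROPRIETARY
)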
""" - - versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( - JumpStartVersionedModelId(model_id, semantic_version_str) - )[0] + if model_type == JumpStartModelType.OPEN_WEIGHTS: + versioned_model_id = self._open_weight_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] + elif model_type == JumpStartModelType.PROPRIETARY: + versioned_model_id = self._proprietary_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] manifest = self._content_cache.get( - JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content + try: header = manifest[versioned_model_id] # type: ignore return header @@ -460,28 +594,34 @@ def _get_header_impl( if attempt > 0: raise self.clear() - return self._get_header_impl(model_id, semantic_version_str, attempt + 1) + return self._get_header_impl(model_id, semantic_version_str, attempt + 1, model_type) - def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs: + def get_specs( + self, + model_id: str, + version_str: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS + ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. Args: model_id (str): model ID for which to get specs. semantic_version_str (str): The semantic version for which to get specs. + model_type (JumpStartModelType): The type of the model of interest. """ - - header = self.get_header(model_id, semantic_version_str) + header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key specs, cache_hit = self._content_cache.get( - JumpStartCachedContentKey(JumpStartS3FileType.SPECS, spec_key) + JumpStartCachedContentKey( + MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key + ) ) - if not cache_hit and "*" in semantic_version_str: + + if not cache_hit and "*" in version_str: JUMPSTART_LOGGER.warning( get_wildcard_model_version_msg( - header.model_id, - semantic_version_str, - header.version + header.model_id, version_str, header.version ) ) return specs.formatted_content @@ -511,4 +651,5 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: def clear(self) -> None: """Clears the model ID/version and s3 cache.""" self._content_cache.clear() - self._model_id_semantic_version_manifest_key_cache.clear() + self._open_weight_model_id_manifest_key_cache.clear() + self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 1fec16b9b3..1b679d44f6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -22,8 +22,9 @@ SerializerType, DeserializerType, MIMEType, + JumpStartModelType, ) -from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo +from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo, JumpStartS3FileType from sagemaker.base_serializers import ( BaseSerializer, CSVSerializer, @@ -169,6 +170,7 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" +JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" @@ -191,6 
+193,9 @@ SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY = "SageMakerGatedModelS3Uri" +PROPRIETARY_MODEL_SPEC_PREFIX = "proprietary-models" +PROPRIETARY_MODEL_FILTER_NAME = "marketplace" + CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP: Dict[MIMEType, SerializerType] = { MIMEType.X_IMAGE: SerializerType.RAW_BYTES, MIMEType.LIST_TEXT: SerializerType.JSON, @@ -216,6 +221,16 @@ DeserializerType.JSON: JSONDeserializer, } +MODEL_TYPE_TO_MANIFEST_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_MANIFEST, +} + +MODEL_TYPE_TO_SPECS_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_SPECS, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_SPECS, +} + MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index f962fdca80..c26258c7d7 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -34,6 +34,19 @@ class ModelFramework(str, Enum): SKLEARN = "sklearn" +class JumpStartModelType(str, Enum): + """Enum class for JumpStart model type. + + OPEN_WEIGHTS: Publicly available models have open weights + and are onboarded and maintained by JumpStart. + PROPRIETARY: Proprietary models from third-party providers do not have open weights. + You must subscribe to proprietary models in AWS Marketplace before use. + """ + + OPEN_WEIGHTS = "open_weights" + PROPRIETARY = "proprietary" + + class VariableScope(str, Enum): """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. @@ -78,6 +91,7 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" + MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" HUB_ARN = "sagemaker-hub:hub-arn" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index d0706a9aab..6406932924 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -36,7 +36,7 @@ from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job from sagemaker.jumpstart.utils import ( - is_valid_model_id, + validate_model_id_and_get_type, resolve_model_sagemaker_config_field, ) from sagemaker.utils import stringify_object, format_tags, Tags @@ -507,8 +507,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. 
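# --- Illustrative sketch (editor's addition, not part of the commit above): the new
# enum and lookup maps pair each model type with its manifest and specs file types,
# which is how the cache decides which S3 artifacts to read.
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.constants import (
    MODEL_TYPE_TO_MANIFEST_MAP,
    MODEL_TYPE_TO_SPECS_MAP,
)

for model_type in JumpStartModelType:
    print(
        model_type.value,
        MODEL_TYPE_TO_MANIFEST_MAP[model_type],
        MODEL_TYPE_TO_SPECS_MAP[model_type],
    )
# Per this diff: open_weights -> OPEN_WEIGHT_MANIFEST / OPEN_WEIGHT_SPECS,
# proprietary -> PROPRIETARY_MANIFEST / PROPRIETARY_SPECS.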
""" - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_get_type_hook(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -516,9 +516,11 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) hub_arn = None @@ -531,6 +533,7 @@ def _is_valid_model_id_hook(): model_id=model_id, model_version=model_version, hub_arn=hub_arn, + model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index c55c9081cb..742a6b8d3f 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -14,7 +14,12 @@ from __future__ import absolute_import from typing import List, Optional -from sagemaker.jumpstart.constants import MODEL_ID_LIST_WEB_URL, JumpStartScriptScope +from botocore.exceptions import ClientError + +from sagemaker.jumpstart.constants import ( + MODEL_ID_LIST_WEB_URL, + JumpStartScriptScope, +) NO_AVAILABLE_INSTANCES_ERROR_MSG = ( "No instances available in {region} that can support model ID '{model_id}'. " @@ -28,7 +33,7 @@ INVALID_MODEL_ID_ERROR_MSG = ( "Invalid model ID: '{model_id}'. Please visit " - f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " + f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " "The module `sagemaker.jumpstart.notebook_utils` contains utilities for " "fetching model IDs. We recommend upgrading to the latest version of sagemaker " "to get access to the most models." @@ -58,6 +63,32 @@ def get_wildcard_model_version_msg( ) +def get_proprietary_model_subscription_msg( + model_id: str, + subscription_link: str, +) -> str: + """Returns customer-facing message for using a proprietary model.""" + + return ( + f"INFO: Using proprietary model '{model_id}'. " + f"To subscribe to this model in AWS Marketplace, see {subscription_link}" + ) + + +def get_wildcard_proprietary_model_version_msg( + model_id: str, wildcard_model_version: str, available_versions: List[str] +) -> str: + """Returns customer-facing message for passing wildcard version to proprietary models.""" + msg = ( + f"Proprietary model '{model_id}' does not support " + f"wildcard version identifier '{wildcard_model_version}'. " + ) + if len(available_versions) > 0: + msg += f"You can pin to version '{available_versions[0]}'. " + msg += f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. 
" + return msg + + def get_old_model_version_msg( model_id: str, current_model_version: str, latest_model_version: str ) -> str: @@ -70,6 +101,15 @@ def get_old_model_version_msg( ) +def get_proprietary_model_subscription_error(error: ClientError, subscription_link: str) -> None: + """Returns customer-facing message associated with a Marketplace subscription error.""" + + error_code = error.response["Error"]["Code"] + error_message = error.response["Error"]["Message"] + if error_code == "ValidationException" and "not subscribed" in error_message: + raise MarketplaceModelSubscriptionError(subscription_link) + + class JumpStartHyperparametersError(ValueError): """Exception raised for bad hyperparameters of a JumpStart model.""" @@ -169,3 +209,29 @@ def __init__( ) super().__init__(self.message) + + +class MarketplaceModelSubscriptionError(ValueError): + """Exception raised when trying to deploy a JumpStart Marketplace model. + + A caller is required to subscribe to the Marketplace product in order to deploy. + This exception is raised when a caller tries to deploy a JumpStart Marketplace model + but the caller is not subscribed to the model. + """ + + def __init__( + self, + model_subscription_link: Optional[str] = None, + message: Optional[str] = None, + ): + if message: + self.message = message + else: + self.message = ( + "To use a proprietary JumpStart model, " + "you must first subscribe to the model in AWS Marketplace. " + ) + if model_subscription_link: + self.message += f"To subscribe to this model, see {model_subscription_link}" + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 71a6419f82..0074374e1a 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -50,7 +50,7 @@ TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model from sagemaker.jumpstart.types import ( JumpStartEstimatorDeployKwargs, @@ -79,6 +79,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -137,6 +138,7 @@ def get_init_kwargs( model_id=model_id, model_version=model_version, hub_arn=hub_arn, + model_type=model_type, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 99ae05fc44..6f7a83cef1 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -37,7 +37,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( JumpStartModelDeployKwargs, JumpStartModelInitKwargs, @@ -73,6 +73,7 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, + model_type: JumpStartModelType = 
JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -95,6 +96,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -104,6 +106,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -113,6 +116,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -122,6 +126,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return predictor @@ -194,6 +199,7 @@ def _add_instance_type_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, + model_type=kwargs.model_type, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -206,7 +212,14 @@ def _add_instance_type_to_kwargs( def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: - """Sets image uri based on default or override, returns full kwargs.""" + """Sets image uri based on default or override, returns full kwargs. 
+ + Uses placeholder image uri for JumpStart proprietary models that uses ModelPackages + """ + + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.image_uri = None + return kwargs kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( region=kwargs.region, @@ -227,6 +240,10 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.model_data = None + return kwargs + model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve( model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, @@ -264,6 +281,10 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets source dir based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.source_dir = None + return kwargs + source_dir = kwargs.source_dir if _model_supports_inference_script_uri( @@ -294,6 +315,10 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets entry point based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.entry_point = None + return kwargs + entry_point = kwargs.entry_point if _model_supports_inference_script_uri( @@ -316,6 +341,10 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets env based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.env = None + return kwargs + env = kwargs.env if env is None: @@ -362,6 +391,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.model_package_arn = model_package_arn @@ -379,6 +409,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in model_kwargs_to_add.items(): @@ -415,6 +446,7 @@ def _add_endpoint_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -437,6 +469,7 @@ def _add_model_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.name = kwargs.name or ( @@ -458,11 +491,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type ) if kwargs.hub_arn: @@ -483,6 +517,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in deploy_kwargs_to_add.items(): @@ -503,6 +538,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, instance_type=kwargs.instance_type, ) @@ -513,6 +549,7 @@ def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -546,6 +583,7 @@ def get_deploy_kwargs( model_id=model_id, model_version=model_version, hub_arn=hub_arn, + model_type=model_type, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -685,6 +723,7 @@ def get_init_kwargs( model_from_estimator: bool = False, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -717,6 +756,7 @@ def get_init_kwargs( model_id=model_id, model_version=model_version, hub_arn=hub_arn, + model_type=model_type, instance_type=instance_type, region=region, image_uri=image_uri, @@ -761,14 +801,12 @@ def get_init_kwargs( # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_source_dir_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_entry_point_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_env_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_predictor_cls_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index b045435ed0..fc5113315d 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -49,6 +49,14 @@ class SpecialSupportedFilterKeys(str, Enum): TASK = "task" FRAMEWORK = "framework" + MODEL_TYPE = "model_type" + + +class ProprietaryModelFilterIdentifiers(str, Enum): + """Enum class for proprietary model filter keys.""" + + PROPRIETARY = "proprietary" + MARKETPLACE = "marketplace" FILTER_OPERATOR_STRING_MAPPINGS = { @@ -429,6 +437,22 @@ def __init__(self, key: str, value: str, 
operator: str): self.value = value self.operator = operator + def set_key(self, key: str) -> None: + """Sets the key for the model filter. + + Args: + key (str): The key to be set. + """ + self.key = key + + def set_value(self, value: str) -> None: + """Sets the value for the model filter. + + Args: + value (str): The value to be set. + """ + self.value = value + def parse_filter_string(filter_string: str) -> ModelFilter: """Parse filter string and return a serialized ``ModelFilter`` object. diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index f96181479a..a8b6d48fa1 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -15,15 +15,22 @@ from __future__ import absolute_import from typing import Dict, List, Optional, Union +from botocore.exceptions import ClientError + from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer from sagemaker.base_serializers import BaseSerializer +from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope -from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG +from sagemaker.jumpstart.exceptions import ( + INVALID_MODEL_ID_ERROR_MSG, + get_proprietary_model_subscription_error, + get_proprietary_model_subscription_msg, +) from sagemaker.jumpstart.factory.model import ( get_default_predictor, get_deploy_kwargs, @@ -31,7 +38,12 @@ get_register_kwargs, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload -from sagemaker.jumpstart.utils import is_valid_model_id +from sagemaker.jumpstart.utils import ( + validate_model_id_and_get_type, + verify_model_region_and_return_specs, +) +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -46,7 +58,6 @@ from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements -from sagemaker.enums import EndpointType class JumpStartModel(Model): @@ -273,8 +284,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. 
""" - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_type(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -282,9 +293,11 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_type() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_type() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) hub_arn = None @@ -294,10 +307,10 @@ def _is_valid_model_id_hook(): ) self._model_data_is_set = model_data is not None - model_init_kwargs = get_init_kwargs( model_id=model_id, model_from_estimator=False, + model_type=self.model_type, model_version=model_version, hub_arn=hub_arn, instance_type=instance_type, @@ -337,10 +350,27 @@ def _is_valid_model_id_hook(): self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + if self.model_type == JumpStartModelType.PROPRIETARY: + self.log_subscription_warning() + super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) self.model_package_arn = model_init_kwargs.model_package_arn + def log_subscription_warning(self) -> None: + """Log message prompting the customer to subscribe to the proprietary model.""" + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + ).model_subscription_link + JUMPSTART_LOGGER.warning( + get_proprietary_model_subscription_msg(self.model_id, subscription_link) + ) + def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: """Returns all example payloads associated with the model. @@ -358,6 +388,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self.model_type, ) def retrieve_example_payload(self) -> JumpStartSerializablePayload: @@ -375,6 +406,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: return payloads.retrieve_example( model_id=self.model_id, model_version=self.model_version, + model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -569,6 +601,9 @@ def deploy( endpoint. endpoint_type (EndpointType): The type of endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + + Raises: + MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. """ deploy_kwargs = get_deploy_kwargs( @@ -600,9 +635,29 @@ def deploy( resources=resources, managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, + model_type=self.model_type, ) + if ( + self.model_type == JumpStartModelType.PROPRIETARY + and endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED + ): + raise ValueError( + f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." 
+            )
 
-        predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
+        try:
+            predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
+        except ClientError as e:
+            subscription_link = verify_model_region_and_return_specs(
+                region=self.region,
+                model_id=self.model_id,
+                version=self.model_version,
+                model_type=self.model_type,
+                scope=JumpStartScriptScope.INFERENCE,
+                sagemaker_session=self.sagemaker_session,
+            ).model_subscription_link
+            get_proprietary_model_subscription_error(e, subscription_link)
+            raise
 
         # If no predictor class was passed, add defaults to predictor
         if self.orig_predictor_cls is None and async_inference_config is None:
@@ -615,6 +670,7 @@ def deploy(
                 tolerate_deprecated_model=self.tolerate_deprecated_model,
                 tolerate_vulnerable_model=self.tolerate_vulnerable_model,
                 sagemaker_session=self.sagemaker_session,
+                model_type=self.model_type,
             )
 
         # If a predictor class was passed, do not mutate predictor
diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py
index 1554025995..485354e802 100644
--- a/src/sagemaker/jumpstart/notebook_utils.py
+++ b/src/sagemaker/jumpstart/notebook_utils.py
@@ -24,10 +24,12 @@
 from sagemaker.jumpstart.constants import (
     DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     JUMPSTART_DEFAULT_REGION_NAME,
+    PROPRIETARY_MODEL_SPEC_PREFIX,
 )
-from sagemaker.jumpstart.enums import JumpStartScriptScope
+from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
 from sagemaker.jumpstart.filters import (
     SPECIAL_SUPPORTED_FILTER_KEYS,
+    ProprietaryModelFilterIdentifiers,
     BooleanValues,
     Identity,
     SpecialSupportedFilterKeys,
@@ -38,6 +40,7 @@
     get_jumpstart_content_bucket,
     get_sagemaker_version,
     verify_model_region_and_return_specs,
+    validate_model_id_and_get_type,
 )
 from sagemaker.session import Session
 
@@ -124,15 +127,11 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
 
     Args:
         model_id (str): The model ID for which to extract the framework/task/model.
-
-    Raises:
-        ValueError: If the model ID cannot be parsed into at least 3 components seperated by
-            "-" character.
     """
     _id_parts = model_id.split("-")
 
     if len(_id_parts) < 3:
-        raise ValueError(f"incorrect model ID: {model_id}.")
+        return "", "", ""
 
     framework = _id_parts[0]
     task = _id_parts[1]
@@ -141,6 +140,20 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
     return framework, task, name
 
 
+def extract_model_type_filter_representation(spec_key: str) -> str:
+    """Parses the model spec key and determines whether the model is proprietary or open weights.
+
+    Args:
+        spec_key (str): The model spec key for which to extract the model type.
+    """
+    model_spec_prefix = spec_key.split("/")[0]
+
+    if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX:
+        return JumpStartModelType.PROPRIETARY.value
+
+    return JumpStartModelType.OPEN_WEIGHTS.value
+
+
 def list_jumpstart_tasks(  # pylint: disable=redefined-builtin
     filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
     region: str = JUMPSTART_DEFAULT_REGION_NAME,
@@ -321,14 +334,22 @@ def _generate_jumpstart_model_versions(  # pylint: disable=redefined-builtin
             to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
""" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=sagemaker_session.s3_client + prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.PROPRIETARY, ) + open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + ) + models_manifest_list = open_weight_manifest_list + prop_models_manifest_list if isinstance(filter, str): filter = Identity(filter) - manifest_keys = set(models_manifest_list[0].__slots__) + manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__) all_keys: Set[str] = set() @@ -338,6 +359,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin model_filter = operator.unresolved_value key = model_filter.key all_keys.add(key) + if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in { + identifier.value for identifier in ProprietaryModelFilterIdentifiers + }: + model_filter.set_value(JumpStartModelType.PROPRIETARY.value) model_filters.add(model_filter) for key in all_keys: @@ -351,6 +376,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys + is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]: @@ -373,6 +399,11 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, SpecialSupportedFilterKeys.FRAMEWORK ] = extract_framework_task_model(model_manifest.model_id)[0] + if is_model_type_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.MODEL_TYPE + ] = extract_model_type_filter_representation(model_manifest.spec_key) + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): return None @@ -466,6 +497,12 @@ def get_model_url( sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to retrieve the model url. 
""" + model_type = validate_model_id_and_get_type( + model_id=model_id, + model_version=model_version, + region=region, + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( region=region, @@ -473,5 +510,6 @@ def get_model_url( version=model_version, sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, + model_type=model_type, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8bdb3eb57b..4dae2383d1 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -21,6 +21,8 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.jumpstart.enums import JumpStartModelType + from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements @@ -120,8 +122,10 @@ def to_json(self) -> Dict[str, Any]: class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" - MANIFEST = "manifest" - SPECS = "specs" + OPEN_WEIGHT_MANIFEST = "manifest" + OPEN_WEIGHT_SPECS = "specs" + PROPRIETARY_MANIFEST = "proptietary_manifest" + PROPRIETARY_SPECS = "proprietary_specs" class HubType(str, Enum): @@ -822,6 +826,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_instance_type_variants", "default_payloads", "gated_bucket", + "model_subscription_link", ] def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False): @@ -842,29 +847,31 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. 
""" self.model_id: str = json_obj["model_id"] - self.url: str = json_obj["url"] + self.url: str = json_obj.get("url", "") self.version: str = json_obj["version"] self.min_sdk_version: str = json_obj["min_sdk_version"] - self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) + self.incremental_training_supported: bool = bool( + json_obj.get("incremental_training_supported", False) + ) self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) if "hosting_ecr_specs" in json_obj else None ) - self.hosting_artifact_key: str = json_obj["hosting_artifact_key"] - self.hosting_script_key: str = json_obj["hosting_script_key"] - self.training_supported: bool = bool(json_obj["training_supported"]) + self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") + self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ JumpStartEnvironmentVariable(env_variable) - for env_variable in json_obj["inference_environment_variables"] + for env_variable in json_obj.get("inference_environment_variables", []) ] - self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"]) - self.inference_dependencies: List[str] = json_obj["inference_dependencies"] - self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"] - self.training_vulnerable: bool = bool(json_obj["training_vulnerable"]) - self.training_dependencies: List[str] = json_obj["training_dependencies"] - self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] - self.deprecated: bool = bool(json_obj["deprecated"]) + self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) + self.inference_dependencies: List[str] = json_obj.get("inference_dependencies", []) + self.inference_vulnerabilities: List[str] = json_obj.get("inference_vulnerabilities", []) + self.training_vulnerable: bool = bool(json_obj.get("training_vulnerable", False)) + self.training_dependencies: List[str] = json_obj.get("training_dependencies", []) + self.training_vulnerabilities: List[str] = json_obj.get("training_vulnerabilities", []) + self.deprecated: bool = bool(json_obj.get("deprecated", False)) self.deprecated_message: Optional[str] = json_obj.get("deprecated_message") self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message") self.usage_info_message: Optional[str] = json_obj.get("usage_info_message") @@ -954,6 +961,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("training_instance_type_variants") else None ) + self.model_subscription_link = json_obj.get("model_subscription_link") def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: """Sets fields in object based on values in HubContentDocument @@ -1278,6 +1286,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1309,6 +1318,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1321,6 +1331,7 @@ def __init__( model_id: str, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: 
Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1351,6 +1362,7 @@ def __init__( self.model_id = model_id self.model_version = model_version self.hub_arn = hub_arn + self.model_type = model_type self.instance_type = instance_type self.region = region self.image_uri = image_uri @@ -1384,6 +1396,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", "initial_instance_count", "instance_type", "region", @@ -1416,6 +1429,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1428,6 +1442,7 @@ def __init__( model_id: str, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1460,6 +1475,7 @@ def __init__( self.model_id = model_id self.model_version = model_version self.hub_arn = hub_arn + self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1495,6 +1511,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", "instance_type", "instance_count", "region", @@ -1555,6 +1572,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", } def __init__( @@ -1562,6 +1580,7 @@ def __init__( model_id: str, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1619,6 +1638,7 @@ def __init__( self.model_id = model_id self.model_version = model_version self.hub_arn = hub_arn + self.model_type = model_type self.instance_type = instance_type self.instance_count = instance_count self.region = region @@ -1681,6 +1701,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", "region", "inputs", "wait", @@ -1696,6 +1717,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_id", "model_version", "hub_arn", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1707,6 +1729,7 @@ def __init__( model_id: str, model_version: Optional[str] = None, hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, @@ -1722,6 +1745,7 @@ def __init__( self.model_id = model_id self.model_version = model_version self.hub_arn = hub_arn + self.model_type = model_type self.region = region self.inputs = inputs self.wait = wait diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f7375b3027..f94f87281d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version @@ -26,6 +26,7 @@ 
TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, TRAINING_JOB_ROLE_ARN_PATH, ) + from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors from sagemaker.s3 import parse_s3_url @@ -314,6 +315,7 @@ def add_single_jumpstart_tag( ( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) + or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) ) if is_uri else False @@ -348,6 +350,7 @@ def add_jumpstart_model_id_version_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, + model_type: Optional[enums.JumpStartModelType] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -364,6 +367,13 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if model_type == enums.JumpStartModelType.PROPRIETARY: + tags = add_single_jumpstart_tag( + enums.JumpStartModelType.PROPRIETARY.value, + enums.JumpStartTag.MODEL_TYPE, + tags, + is_uri=False, + ) return tags @@ -546,6 +556,7 @@ def verify_model_region_and_return_specs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -597,6 +608,7 @@ def verify_model_region_and_return_specs( hub_arn=hub_arn, version=version, s3_client=sagemaker_session.s3_client, + model_type=model_type, ) if ( @@ -751,36 +763,52 @@ def resolve_estimator_sagemaker_config_field( return field_val -def is_valid_model_id( +def validate_model_id_and_get_type( model_id: Optional[str], region: Optional[str] = None, model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> bool: - """Returns True if the model ID is supported for the given script. +) -> Optional[enums.JumpStartModelType]: + """Returns model type if the model ID is supported for the given script. Raises: ValueError: If the script is not supported by JumpStart. 
""" + + def _get_model_type( + model_id: str, + open_weights_model_ids: Set[str], + proprietary_model_ids: Set[str], + script: enums.JumpStartScriptScope, + ) -> Optional[enums.JumpStartModelType]: + if model_id in open_weights_model_ids: + return enums.JumpStartModelType.OPEN_WEIGHTS + if model_id in proprietary_model_ids: + if script == enums.JumpStartScriptScope.INFERENCE: + return enums.JumpStartModelType.PROPRIETARY + raise ValueError(f"Unsupported script for Marketplace models: {script}") + return None + if model_id in {None, ""}: - return False + return None if not isinstance(model_id, str): - return False + return None s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=s3_client + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHTS + ) + open_weight_model_id_set = {model.model_id for model in models_manifest_list} + + proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY ) - model_id_set = {model.model_id for model in models_manifest_list} - if script == enums.JumpStartScriptScope.INFERENCE: - return model_id in model_id_set - if script == enums.JumpStartScriptScope.TRAINING: - return model_id in model_id_set - raise ValueError(f"Unsupported script: {script}") + + proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list} + return _get_model_type(model_id, open_weight_model_id_set, proprietary_model_id_set, script) def get_jumpstart_model_id_version_from_resource_arn( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index af08d1203f..c5f8b86f60 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -141,7 +141,7 @@ class Model(ModelBase, InferenceRecommenderMixin): def __init__( self, - image_uri: Union[str, PipelineVariable], + image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 52d633ed4e..de33f61b82 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.payload_utils import PayloadSerializer from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -31,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -83,6 +85,7 @@ def retrieve_all_examples( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if unserialized_payload_dict is None: @@ -120,6 +123,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, 
tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -133,6 +137,8 @@ def retrieve_example( the model payload. model_version (str): The version of the JumpStart model for which to retrieve the model payload. + model_type (str): The model type of the JumpStart model, either is open weight + or proprietary. serialize (bool): Whether to serialize byte-stream valued payloads by downloading binary files from s3 and applying encoding, or to keep payload in pre-serialized state. Set this option to False if you want to avoid s3 downloads or if you @@ -162,6 +168,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index aaa1c1d797..d3f41bd9a6 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -15,6 +15,7 @@ from typing import Optional from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint @@ -42,6 +43,7 @@ def retrieve_default( hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. @@ -111,4 +113,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 62389ba127..eac9dfa5b1 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session LOGGER = logging.getLogger("sagemaker") @@ -34,6 +35,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: @@ -86,6 +88,7 @@ def retrieve_default( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 43b5a9fa34..0b12e707f4 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -34,6 +34,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -98,6 +99,7 @@ def retrieve_default( hub_arn: Optional[str] = None, tolerate_vulnerable_model: 
bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -142,4 +144,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 24050807cc..5205765e2f 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -256,9 +256,33 @@ def test_jumpstart_model_register(setup): predictor = model_package.deploy( instance_type="ml.p3.2xlarge", initial_instance_count=1, - tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], ) response = predictor.predict("hello world!") assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_proprietary_jumpstart_model(setup): + + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + model_version="2.0.004", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}] + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + assert response is not None diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 4284ac4d84..d05d765d25 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -17,21 +17,27 @@ from sagemaker import accept_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +51,27 @@ def test_jumpstart_default_accept_types( assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + 
version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -73,5 +88,10 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index c924417946..6d20ca2169 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -17,6 +17,7 @@ from sagemaker import content_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -24,14 +25,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +50,27 @@ def test_jumpstart_default_content_types( assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + 
patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -72,5 +86,10 @@ def test_jumpstart_supported_content_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 4807fc7933..0a4cfa7db6 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -18,6 +18,7 @@ from sagemaker import base_deserializers, deserializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -26,14 +27,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_deserializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -47,18 +52,27 @@ def test_jumpstart_default_deserializers( assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_deserializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -79,5 +93,10 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, 
model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index d89687b3f4..d1abc2bb68 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -18,17 +18,23 @@ import pytest from sagemaker import environment_variables +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_environment_variables(patched_get_model_specs): +def test_jumpstart_default_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -48,7 +54,12 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -68,7 +79,12 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -98,10 +114,14 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_sdk_environment_variables(patched_get_model_specs): +def test_jumpstart_sdk_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -122,7 +142,12 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -143,7 +168,12 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None + 
region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index ba08aa0825..f2f9d13939 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import hyperparameters +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -26,10 +27,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_hyperparameters(patched_get_model_specs): +def test_jumpstart_default_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -43,7 +48,12 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -57,7 +67,12 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -79,7 +94,12 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index ae8138b7c5..0f69cb572a 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -17,7 +17,7 @@ import pytest import boto3 from sagemaker import hyperparameters -from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter @@ -27,8 +27,11 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") 
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_provided_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.extend( @@ -109,6 +112,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -136,7 +140,12 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -395,8 +404,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_algorithm_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_algorithm_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.append( @@ -413,6 +425,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -434,7 +447,12 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -461,10 +479,14 @@ def add_options_to_hyperparameter(*largs, **kwargs): assert str(e.value) == "Cannot find algorithm hyperparameter for 'adam-learning-rate'." 
+@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_all_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -488,7 +510,12 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index b37b770875..bd4383499d 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -18,19 +18,24 @@ from sagemaker import image_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_image_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -50,6 +55,7 @@ def test_jumpstart_common_image_uri( version="*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -71,6 +77,7 @@ def test_jumpstart_common_image_uri( version="1.*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -92,6 +99,7 @@ def test_jumpstart_common_image_uri( version="*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -113,6 +121,7 @@ def test_jumpstart_common_image_uri( version="1.*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 50f35cb872..ea6835bec3 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py 
+++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -18,14 +18,17 @@ import pytest from sagemaker import instance_types +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_instance_types(patched_get_model_specs): +def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_model_id_and_get_type): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" @@ -48,6 +51,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): version=model_version, hub_arn=None, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -67,6 +71,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): version=model_version, hub_arn=None, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -92,6 +97,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): version=model_version, hub_arn=None, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -119,6 +125,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): version=model_version, hub_arn=None, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 43777ab14a..cb0f4831b7 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6113,6 +6113,7 @@ "deprecated_message": None, "hosting_model_package_arns": None, "hosting_eula_key": None, + "model_subscription_link": None, "hyperparameters": [ { "name": "epochs", @@ -6309,3 +6310,83 @@ "imagenet-inception-v3-classification-4/specs_v3.0.0.json", }, ] + +BASE_PROPRIETARY_HEADER = { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], +} + +BASE_PROPRIETARY_MANIFEST = [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "lighton-mini-instruct40b", + "version": "v1.0", + "min_version": "2.0.0", + "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "1.0.005", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, +] + +BASE_PROPRIETARY_SPEC = { + "model_id": "ai21-jurassic-2-light", + "version": "2.0.004", + "min_sdk_version": "2.999.0", + "listing_id": "prodview-roz6zicyvi666", + "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", + "model_subscription_link": 
"https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", + "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 600, + }, + "default_payloads": { + "Shakespeare": { + "content_type": "application/json", + "prompt_key": "prompt", + "output_keys": {"generated_text": "[0].completions[0].data.text"}, + "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, + } + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "hosting_model_package_arns": { + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", + "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", + "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", + "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", + "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", + "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", + "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", + "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", + "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", + "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", + "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", + }, +} diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index d88961ebb7..f381573fe8 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -31,7 +31,7 @@ _retrieve_default_training_metric_definitions, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.estimator import JumpStartEstimator @@ -61,9 +61,10 @@ class EstimatorTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") 
@mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -76,14 +77,15 @@ def test_non_prepacked( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, + mock_get_model_type: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, mock_jumpstart_estimator_factory_logger: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "9876" @@ -93,6 +95,8 @@ def test_non_prepacked( mock_get_model_specs.side_effect = get_special_model_spec + mock_get_model_type.return_value = JumpStartModelType.OPEN_WEIGHTS + mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session @@ -183,7 +187,7 @@ def test_non_prepacked( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -200,11 +204,11 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -423,7 +427,7 @@ def test_hub_model( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -440,14 +444,14 @@ def test_gated_model_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-trainable-model", "*" @@ -559,7 +563,7 @@ def test_gated_model_s3_uri( "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") 
@mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -576,7 +580,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_get_jumpstart_gated_content_bucket: mock.Mock, ): @@ -585,7 +589,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" @@ -707,7 +711,7 @@ def test_gated_model_non_model_package_s3_uri( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -724,15 +728,13 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True - model_id, _ = "js-gated-artifact-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -740,6 +742,8 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + with pytest.raises(ValueError) as e: JumpStartEstimator(model_id=model_id, region="eu-north-1") @@ -749,7 +753,7 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( "us-west-2, us-east-1, eu-west-1, ap-southeast-1." 
) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -762,10 +766,10 @@ def test_deprecated( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -783,7 +787,7 @@ def test_deprecated( JumpStartEstimator(model_id=model_id, tolerate_deprecated_model=True).fit(channels).deploy() - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -796,9 +800,9 @@ def test_vulnerable( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "vulnerable_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -899,7 +903,7 @@ def test_estimator_use_kwargs(self): @mock.patch("sagemaker.jumpstart.factory.estimator.metric_definitions_utils.retrieve_default") @mock.patch("sagemaker.jumpstart.factory.estimator.environment_variables.retrieve_default") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -916,7 +920,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_retrieve_default_environment_variables: mock.Mock, mock_retrieve_metric_definitions: mock.Mock, @@ -947,7 +951,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "js-trainable-model", "*" @@ -1049,16 +1053,16 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.assert_called_once_with(**expected_deploy_kwargs) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags_disabled( self, mock_get_model_specs: mock.Mock, - 
mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1088,16 +1092,16 @@ def test_jumpstart_estimator_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1130,18 +1134,18 @@ def test_jumpstart_estimator_tags( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", @@ -1173,18 +1177,18 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.side_effect = ValueError() @@ -1257,22 +1261,22 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + 
mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartEstimator(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartEstimator(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1289,14 +1293,14 @@ def test_no_predictor_returns_default_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1332,7 +1336,7 @@ def test_no_predictor_returns_default_predictor( self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1349,14 +1353,14 @@ def test_no_predictor_yes_async_inference_config( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1382,7 +1386,7 @@ def test_no_predictor_yes_async_inference_config( self.assertEqual(type(predictor), Predictor) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1399,14 +1403,14 @@ def test_yes_predictor_returns_unmodified_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor 
mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1432,7 +1436,7 @@ def test_yes_predictor_returns_unmodified_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1453,9 +1457,9 @@ def test_incremental_training_with_unsupported_model_logs_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1487,7 +1491,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( sagemaker_session=sagemaker_session, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1508,9 +1512,9 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1540,7 +1544,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1557,10 +1561,10 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1601,7 +1605,7 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + 
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1618,10 +1622,10 @@ def test_training_passes_role_to_deploy( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1678,7 +1682,7 @@ def test_training_passes_role_to_deploy( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session ) @@ -1698,10 +1702,10 @@ def test_training_passes_session_to_deploy( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1756,7 +1760,7 @@ def test_training_passes_session_to_deploy( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -1772,11 +1776,11 @@ def test_model_id_not_found_refeshes_cache_training( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -1792,7 +1796,7 @@ def test_model_id_not_found_refeshes_cache_training( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1811,16 +1815,16 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [False, True] JumpStartEstimator( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1839,7 +1843,7 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") 
@mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -1849,10 +1853,10 @@ def test_model_artifact_variant_estimator( mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index d22e910a00..073921d5ba 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -78,7 +79,7 @@ def config_value_impl(sagemaker_session: Session, config_path: str, sagemaker_co class IntelligentDefaultsEstimatorTest(unittest.TestCase): - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -98,12 +99,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_model_init_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -135,7 +136,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -155,12 +156,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -208,7 +209,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( config_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + 
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -228,12 +229,12 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -290,7 +291,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -310,12 +311,12 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -365,7 +366,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -387,12 +388,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -426,7 +427,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -448,12 +449,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): 
mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_get_caller_identity_arn.return_value = execution_role model_id, _ = "js-trainable-model", "*" @@ -500,7 +501,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( metadata_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -520,11 +521,11 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -576,7 +577,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -594,11 +595,11 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index e7c00887fd..9d076b6b93 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -22,7 +22,7 @@ _retrieve_default_environment_variables, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model @@ -32,9 +32,11 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, + get_prototype_manifest, ) execution_role = "fake role! do not use!" 
@@ -53,7 +55,7 @@ class ModelTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -65,7 +67,7 @@ def test_non_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, ): @@ -73,7 +75,7 @@ def test_non_prepacked( mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -128,7 +130,7 @@ def test_non_prepacked( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -140,14 +142,14 @@ def test_non_prepacked_inference_component_based_endpoint( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = ( @@ -208,7 +210,7 @@ def test_non_prepacked_inference_component_based_endpoint( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -220,14 +222,14 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -282,7 +284,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + 
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -294,11 +296,11 @@ def test_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -345,7 +347,7 @@ def test_prepacked( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -353,7 +355,7 @@ def test_no_compiled_model_warning_log_js_models( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, @@ -362,7 +364,7 @@ def test_no_compiled_model_warning_log_js_models( mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_llama_neuron_model", "*" @@ -381,7 +383,7 @@ def test_no_compiled_model_warning_log_js_models( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -389,15 +391,14 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, ): - mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_variant-model", "*" @@ -441,7 +442,64 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_proprietary_model_endpoint( + self, + mock_model_deploy: mock.Mock, + mock_model_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_sagemaker_timestamp.return_value = "7777" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.PROPRIETARY + model_id, _ = "ai21-summarization", "*" + + mock_get_model_specs.side_effect = get_spec_from_base_spec + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, model_version="2.0.004") + + mock_model_init.assert_called_once_with( + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + ) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p4de.24xlarge", + wait=True, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.004"}, + {"Key": JumpStartTag.MODEL_TYPE, "Value": "proprietary"}, + ], + endpoint_logging=False, + model_data_download_timeout=3600, + container_startup_health_check_timeout=600, + ) + + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -451,11 +509,11 @@ def test_deprecated( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -468,7 +526,7 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -478,9 +536,9 @@ def test_vulnerable( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_model_deploy.return_value = default_predictor @@ -543,7 +601,7 @@ def test_model_use_kwargs(self): ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + 
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -555,7 +613,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, @@ -565,7 +623,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_session.return_value = sagemaker_session @@ -662,22 +720,22 @@ def test_jumpstart_model_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -689,14 +747,14 @@ def test_no_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -719,12 +777,13 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -736,14 +795,14 
@@ def test_no_predictor_yes_async_inference_config( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -760,7 +819,7 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -772,14 +831,14 @@ def test_yes_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -795,24 +854,24 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.JumpStartModelsAccessor.reset_cache") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) - def test_model_id_not_found_refeshes_cach_inference( + def test_model_id_not_found_refeshes_cache_inference( self, mock_reset_cache: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -828,7 +887,7 @@ def test_model_id_not_found_refeshes_cach_inference( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -847,16 +906,19 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [ + False, + JumpStartModelType.OPEN_WEIGHTS, 
+ ] JumpStartModel( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -875,16 +937,16 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -911,16 +973,16 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -943,16 +1005,16 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -977,16 +1039,16 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS # arbitrary model without model packarn arn model_id, _ = "js-trainable-model", "*" @@ -1019,7 +1081,7 @@ def test_jumpstart_model_package_arn_override( }, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1027,10 +1089,10 @@ def test_jumpstart_model_package_arn_unsupported_region( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -1046,7 +1108,7 @@ def test_jumpstart_model_package_arn_unsupported_region( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1060,14 +1122,14 @@ def test_model_data_s3_prefix_override( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1111,7 +1173,7 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1125,11 +1187,11 @@ def test_model_data_s3_prefix_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1155,7 +1217,7 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1169,11 +1231,11 @@ def test_model_artifact_variant_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" 
mock_get_model_specs.side_effect = get_special_model_spec @@ -1220,7 +1282,7 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1232,11 +1294,11 @@ def test_model_registry_accept_and_response_types( mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 727f3060b3..70409704e6 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.model import JumpStartModel +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -59,7 +60,7 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -75,10 +76,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_retrieve_kwargs.return_value = {} @@ -100,7 +101,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -116,10 +117,10 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -146,7 
+147,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -162,10 +163,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -192,7 +193,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -208,10 +209,10 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -240,7 +241,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -256,10 +257,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -286,7 +287,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -302,9 +303,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS 
model_id, _ = "js-trainable-model", "*" @@ -333,7 +334,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -349,10 +350,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -374,7 +375,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -390,10 +391,10 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 460494e116..fb22287909 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -18,6 +18,7 @@ import pytest from sagemaker.jumpstart import accessors +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST from tests.unit.sagemaker.jumpstart.utils import ( get_header_from_base_header, @@ -97,12 +98,51 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache): ) ) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_proprietary_models_cache_get(mock_cache): + + mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) + mock_cache.get_header = Mock(side_effect=get_header_from_base_header) + mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) + + assert get_header_from_base_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert get_spec_from_base_spec( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + 
model_type=JumpStartModelType.PROPRIETARY, + ) + + assert ( + len( + accessors.JumpStartModelsAccessor._get_manifest( + model_type=JumpStartModelType.PROPRIETARY + ) + ) + > 0 + ) + # necessary because accessors is a static module reload(accessors) @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") -def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): +def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock): # test change of region resets cache accessors.JumpStartModelsAccessor.get_model_header( @@ -181,6 +221,50 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): reload(accessors) +@patch("sagemaker.jumpstart.cache.JumpStartModelsCache") +def test_jumpstart_proprietary_models_cache_set_reset(mock_model_cache: Mock): + + # test change of region resets cache + accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_header( + region="us-east-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-1", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + # necessary because accessors is a static module + reload(accessors) + + class TestS3Accessor(TestCase): bucket = "bucket" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 8acd04f1f6..200c6a5cbc 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -43,7 +43,7 @@ from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn from sagemaker.jumpstart.artifacts.model_uris import _retrieve_model_uri -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec from tests.unit.sagemaker.workflow.conftest import mock_client @@ -342,9 +342,13 @@ class RetrieveModelPackageArnTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_retrieve_model_package_arn(self, patched_get_model_specs): + def test_retrieve_model_package_arn( + self, patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "variant-model" region = "us-west-2" @@ -448,9 +452,13 @@ class PrivateJumpStartBucketTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def 
test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): + def test_retrieve_uri_from_gated_bucket( + self, patched_get_model_specs, patched_validate_model_id_and_get_type + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "private-model" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index db1102efb2..348e93e7b7 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -22,7 +22,12 @@ from mock.mock import MagicMock import pytest from mock import patch -from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache + +from sagemaker.jumpstart.cache import ( + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, + JumpStartModelsCache, +) from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -31,7 +36,9 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + JumpStartS3FileType, ) +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import ( get_spec_from_base_spec, patched_retrieval_function, @@ -40,6 +47,8 @@ from tests.unit.sagemaker.jumpstart.constants import ( BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_SPEC, + BASE_PROPRIETARY_MANIFEST, ) from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @@ -175,6 +184,33 @@ def test_jumpstart_cache_get_header(): semantic_version_str="1.0.*", ) + assert JumpStartModelHeader( + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + } + ) == cache.get_header( + model_id="ai21-summarization", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="3.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "You can pin to version '1.1.003'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for a list of valid model IDs. " in str(e.value) + ) + with pytest.raises(KeyError) as e: cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -217,6 +253,19 @@ def test_jumpstart_cache_get_header(): "v3-classification-4'?" ) in str(e.value) + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarize", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for updated list of models. " + "Did you mean to use model ID 'ai21-summarization'?" 
+ ) in str(e.value) + with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -247,6 +296,27 @@ def test_jumpstart_cache_get_header(): semantic_version_str="*", ) + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.1.004", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="2.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="v*", + model_type=JumpStartModelType.PROPRIETARY, + ) + @patch("boto3.client") def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): @@ -299,6 +369,12 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.set_manifest_file_s3_key("some_key1") cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + cache.clear.assert_called_once() + with pytest.raises(ValueError): + cache.set_manifest_file_s3_key("some_key1", file_type="unknown_type") + def test_jumpstart_cache_handles_boto3_client_errors(): # Testing get_object @@ -445,15 +521,80 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache._content_cache._max_cache_items == max_s3_cache_items assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( - cache._model_id_semantic_version_manifest_key_cache._max_cache_items + cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items ) assert ( - cache._model_id_semantic_version_manifest_key_cache._expiration_horizon + cache._open_weight_model_id_manifest_key_cache._expiration_horizon == semantic_version_cache_expiration_horizon ) +def test_jumpstart_proprietary_cache_accepts_input_parameters(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + + assert ( + cache.get_manifest_file_s3_key(file_type=JumpStartS3FileType.PROPRIETARY_MANIFEST) + == proprietary_manifest_file_key + ) + assert cache.get_region() == region + assert cache.get_bucket() == bucket + assert cache._s3_cache._max_cache_items == max_s3_cache_items + assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert ( + cache._proprietary_model_id_manifest_key_cache._max_cache_items + == max_semantic_version_cache_items + ) + assert ( + cache._proprietary_model_id_manifest_key_cache._expiration_horizon + == semantic_version_cache_expiration_horizon + ) + + +def test_jumpstart_cache_raise_unknown_file_type_exception(): + + region = "us-east-1" + max_s3_cache_items = 1 
+ s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + with pytest.raises(ValueError): + cache.get_manifest_file_s3_key(file_type="unknown_type") + + @patch("boto3.client") def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): @@ -605,7 +746,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( with patch("logging.Logger.warning") as mocked_warning_log: cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_called_once_with( "Using model 'pytorch-ic-imagenet-inception-v3-classification-4' with wildcard " @@ -615,7 +756,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( ) mocked_warning_log.reset_mock() cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_not_called() @@ -627,13 +768,85 @@ def test_jumpstart_cache_makes_correct_s3_calls( mock_boto3_client.return_value.head_object.assert_not_called() +@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs") +@patch("boto3.client") +def test_jumpstart_cache_proprietary_manifest_makes_correct_s3_calls( + mock_boto3_client, mock_emit_logs_based_on_model_specs +): + + # test get_header + mock_manifest_json = json.dumps( + [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + ] + ) + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_manifest_json, "utf-8")), content_length=len(mock_manifest_json) + ), + "ETag": "etag", + } + + mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} + + bucket_name = get_jumpstart_content_bucket("us-west-2") + client_config = botocore.config.Config(signature_version="my_signature_version") + cache = JumpStartModelsCache( + s3_bucket_name=bucket_name, s3_client_config=client_config, region="us-west-2" + ) + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + mock_boto3_client.assert_called_with("s3", region_name="us-west-2", config=client_config) + + # test get_specs. manifest already in cache, so only s3 call will be to get specs. 
+ mock_json = json.dumps(BASE_PROPRIETARY_SPEC) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "etag", + } + + with patch("logging.Logger.warning") as mocked_warning_log: + cache.get_specs( + model_id="ai21-summarization", + version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + mocked_warning_log.assert_not_called() + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, + Key="proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") cache.clear = MagicMock() - cache._model_id_semantic_version_manifest_key_cache = MagicMock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache = MagicMock() + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -662,7 +875,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() cache.clear.reset_mock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -690,7 +903,18 @@ def test_jumpstart_get_full_manifest(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") raw_manifest = [header.to_json() for header in cache.get_manifest()] - raw_manifest == BASE_MANIFEST + assert raw_manifest == BASE_MANIFEST + + +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_get_full_proprietary_manifest(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + raw_manifest = [ + header.to_json() for header in cache.get_manifest(model_type=JumpStartModelType.PROPRIETARY) + ] + + assert raw_manifest == BASE_PROPRIETARY_MANIFEST @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @@ -700,54 +924,92 @@ def test_jumpstart_cache_get_specs(): model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="2.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="2.0.*" + model_id=model_id, version_str="2.0.*" ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == 
cache.get_specs( - model_id=model_id, semantic_version_str="1.*" + model_id=model_id, version_str="1.*" ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="1.0.*" + model_id=model_id, version_str="1.0.*" + ) + + assert get_spec_from_base_spec( + model_id="ai21-summarization", + version="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) == cache.get_specs( + model_id="ai21-summarization", + version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError) as e: + cache.get_specs( + model_id="ai21-summarization", + version_str="3.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "You can pin to version '1.1.003'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for a list of valid model IDs. " in str(e.value) ) with pytest.raises(KeyError): - cache.get_specs(model_id=model_id + "bak", semantic_version_str="*") + cache.get_specs(model_id=model_id + "bak", version_str="*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="9.*") + cache.get_specs(model_id=model_id, version_str="9.*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="BAD") + cache.get_specs(model_id=model_id, version_str="BAD") with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="2.1.*", + version_str="2.1.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="3.9.*", + version_str="3.9.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="5.*", + version_str="5.*", + ) + + model_id, version = "ai21-summarization", "2.0.0" + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="BAD", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="9.*", + model_type=JumpStartModelType.PROPRIETARY, ) @@ -822,9 +1084,7 @@ def test_jumpstart_local_metadata_override_specs( ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" - assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( - model_id=model_id, semantic_version_str=version - ) + assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs(model_id=model_id, version_str=version) mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") @@ -868,7 +1128,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( cache = JumpStartModelsCache(s3_bucket_name="some_bucket") assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") diff --git a/tests/unit/sagemaker/jumpstart/test_exceptions.py b/tests/unit/sagemaker/jumpstart/test_exceptions.py index 555099a753..2307d22474 100644 --- a/tests/unit/sagemaker/jumpstart/test_exceptions.py +++ b/tests/unit/sagemaker/jumpstart/test_exceptions.py @@ -11,10 +11,15 @@ # ANY KIND, either express or implied. 
See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + +from botocore.exceptions import ClientError from sagemaker.jumpstart.exceptions import ( get_wildcard_model_version_msg, get_old_model_version_msg, + get_proprietary_model_subscription_error, + MarketplaceModelSubscriptionError, ) @@ -35,3 +40,32 @@ def test_get_old_model_version_msg(): "Note that models may have different input/output signatures after a major " "version upgrade." == get_old_model_version_msg("mother_of_all_models", "1.0.0", "1.2.3") ) + + +def test_get_marketplace_subscription_error(): + error = ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Caller is not subscribed to the Marketplace listing.", + }, + }, + operation_name="mock-operation", + ) + with pytest.raises(MarketplaceModelSubscriptionError): + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + + error = ClientError( + error_response={ + "Error": { + "Code": "UnknownException", + "Message": "Unknown error raised.", + }, + }, + operation_name="mock-operation", + ) + + try: + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + except MarketplaceModelSubscriptionError: + pytest.fail("MarketplaceModelSubscriptionError should not be raised for unknown error.") diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 60707a5286..a5d1ee3ac2 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -17,6 +17,7 @@ get_prototype_manifest, get_prototype_model_spec, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, get_model_url, @@ -63,7 +64,7 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 assert patched_get_model_specs.call_count == 1 patched_get_model_specs.reset_mock() @@ -78,8 +79,8 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() - assert patched_read_s3_file.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) + assert patched_get_manifest.call_count == 2 + assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -107,7 +108,7 @@ def test_list_jumpstart_tasks( ) # incomplete list, based on mocked metadata patched_generate_jumpstart_models.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() @@ -122,7 +123,7 @@ def test_list_jumpstart_tasks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() @@ -154,11 +155,11 @@ def test_list_jumpstart_frameworks( ) patched_generate_jumpstart_models.assert_called_once() - 
patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() - patched_get_manifest.reset_mock() + assert patched_get_manifest.call_count == 2 patched_generate_jumpstart_models.reset_mock() kwargs = { @@ -180,7 +181,7 @@ def test_list_jumpstart_frameworks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 4 patched_get_model_specs.assert_not_called() @@ -238,7 +239,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": f"training_supported != {val}"} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -255,7 +256,7 @@ def test_list_jumpstart_models_script_filter( ("xgboost-classification-model", "1.0.0"), ] assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -264,7 +265,7 @@ def test_list_jumpstart_models_script_filter( models = list_jumpstart_models(**kwargs) assert [] == models assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -287,7 +288,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task == {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -295,7 +296,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task != {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -312,7 +313,7 @@ def test_list_jumpstart_models_task_filter( ("xgboost-classification-model", "1.0.0"), ] patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -321,7 +322,7 @@ def test_list_jumpstart_models_task_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") @@ -348,7 +349,7 @@ def test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework == {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -356,7 +357,7 @@ def 
test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework != {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -372,7 +373,7 @@ def test_list_jumpstart_models_framework_filter( ("xgboost-classification-model", "1.0.0"), ] patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -390,8 +391,8 @@ def test_list_jumpstart_models_framework_filter( } models = list_jumpstart_models(**kwargs) assert [("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0")] == models - patched_read_s3_file.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -405,7 +406,7 @@ def test_list_jumpstart_models_framework_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -420,8 +421,10 @@ def test_list_jumpstart_models_region( list_jumpstart_models(region="some-region") - patched_get_manifest.assert_called_once_with( - region="some-region", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client + patched_get_manifest.assert_called_with( + region="some-region", + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -478,7 +481,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): ] == list_jumpstart_models(list_old_models=True, list_versions=True) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -526,8 +529,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -538,8 +541,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -570,8 +573,8 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) assert [] == list_jumpstart_models("deprecated equals false") - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * 
num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -607,6 +610,43 @@ def test_list_jumpstart_models_no_versions( assert list_jumpstart_models(list_versions=False) == all_model_ids + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + def test_list_jumpstart_proprietary_models( + self, + patched_get_model_specs: Mock, + patched_get_manifest: Mock, + ): + patched_get_model_specs.side_effect = get_prototype_model_spec + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + all_prop_model_ids = [ + "ai21-paraphrase", + "ai21-summarization", + "lighton-mini-instruct40b", + ] + + all_open_weight_model_ids = [ + "catboost-classification-model", + "huggingface-spc-bert-base-cased", + "lightgbm-classification-model", + "mxnet-semseg-fcn-resnet50-ade", + "pytorch-eqa-bert-base-cased", + "sklearn-classification-linear", + "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", + "xgboost-classification-model", + ] + + assert list_jumpstart_models("model_type == proprietary") == all_prop_model_ids + assert list_jumpstart_models("model_type == marketplace") == all_prop_model_ids + assert list_jumpstart_models("model_type == open_weights") == all_open_weight_model_ids + + assert list_jumpstart_models(list_versions=False) == sorted( + all_prop_model_ids + all_open_weight_model_ids + ) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_complex_queries( @@ -670,12 +710,20 @@ def test_list_jumpstart_models_multiple_level_index( list_jumpstart_models("hosting_ecr_specs.py_version == py3") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_get_model_url( patched_get_model_specs: Mock, + patched_validate_model_id_and_get_type: Mock, + patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -686,7 +734,6 @@ def test_get_model_url( ) model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" - region = "fake-region" patched_get_model_specs.reset_mock() patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_prototype_model_spec( @@ -695,12 +742,13 @@ def test_get_model_url( **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region=region) + get_model_url(model_id, version, region="us-west-2") patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, - region=region, + region="us-west-2", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 
3cc2314a59..1266e011b3 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -7,17 +7,15 @@ import pytest from sagemaker.deserializers import JSONDeserializer from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import MIMEType, JumpStartModelType from sagemaker import predictor from sagemaker.jumpstart.model import JumpStartModel from sagemaker.jumpstart.utils import verify_model_region_and_return_specs -from sagemaker.serializers import IdentitySerializer -from tests.unit.sagemaker.jumpstart.utils import ( - get_special_model_spec, -) +from sagemaker.serializers import IdentitySerializer, JSONSerializer +from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -54,6 +52,43 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON +@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_proprietary_predictor_support( + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_get_jumpstart_model_id_version_from_endpoint, +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + # version not needed for JumpStart predictor + model_id, model_version = "ai21-summarization", "*" + + patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + model_id, + model_version, + None, + ) + + js_predictor = predictor.retrieve_default( + endpoint_name="blah", + model_id=model_id, + model_version=model_version, + model_type=JumpStartModelType.PROPRIETARY, + ) + + patched_get_jumpstart_model_id_version_from_endpoint.assert_not_called() + + assert js_predictor.content_type == MIMEType.JSON + assert isinstance(js_predictor.serializer, JSONSerializer) + + assert isinstance(js_predictor.deserializer, JSONDeserializer) + assert js_predictor.accept == MIMEType.JSON + + @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -93,6 +128,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=mock_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @@ -126,19 +162,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") -@patch("sagemaker.jumpstart.model.is_valid_model_id") +@patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_is_valid_model_id, + patched_validate_model_id_and_get_type, patched_get_object_cached, patched_get_model_id_version_from_endpoint, ): patched_get_object_cached.return_value = 
base64.b64decode("encodedimage") - patched_is_valid_model_id.return_value = True + patched_validate_model_id_and_get_type.return_value = True patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c42c15ecf5..3ec5ba30ec 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -32,7 +32,7 @@ JumpStartScriptScope, ) from functools import partial -from sagemaker.jumpstart.enums import JumpStartTag, MIMEType +from sagemaker.jumpstart.enums import JumpStartTag, MIMEType, JumpStartModelType from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, VulnerableJumpStartModelError, @@ -71,7 +71,7 @@ def test_get_jumpstart_content_bucket_override(): with patch("logging.Logger.info") as mocked_info_log: random_region = "random_region" assert "some-val" == utils.get_jumpstart_content_bucket(random_region) - mocked_info_log.assert_called_once_with("Using JumpStart bucket override: 'some-val'") + mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") def test_get_jumpstart_gated_content_bucket(): @@ -1221,7 +1221,7 @@ def test_mime_type_enum_from_str(): class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_true( + def test_validate_model_id_and_get_type_true( self, mock_get_model_specs: Mock, mock_get_manifest: Mock, @@ -1235,12 +1235,16 @@ def test_is_valid_model_id_true( mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): - self.assertTrue(utils.is_valid_model_id("bee")) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): + self.assertTrue(utils.validate_model_id_and_get_type("bee")) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_model_specs.assert_not_called() @@ -1254,14 +1258,20 @@ def test_is_valid_model_id_true( ] mock_get_model_specs.return_value = Mock(training_supported=True) - self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("bee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_manifest: Mock): + def 
test_validate_model_id_and_get_type_false( + self, mock_get_model_specs: Mock, mock_get_manifest: Mock + ): mock_get_manifest.return_value = [ Mock(model_id="ay"), Mock(model_id="bee"), @@ -1271,18 +1281,18 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): - self.assertFalse(utils.is_valid_model_id("dee")) - self.assertFalse(utils.is_valid_model_id("")) - self.assertFalse(utils.is_valid_model_id(None)) - self.assertFalse(utils.is_valid_model_id(set())) + self.assertFalse(utils.validate_model_id_and_get_type("dee")) + self.assertFalse(utils.validate_model_id_and_get_type("")) + self.assertFalse(utils.validate_model_id_and_get_type(None)) + self.assertFalse(utils.validate_model_id_and_get_type(set())) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value - ) + mock_get_manifest.assert_called() mock_get_model_specs.assert_not_called() @@ -1294,30 +1304,48 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type("", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(None, script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(set(), script=JumpStartScriptScope.TRAINING) + ) mock_get_model_specs.assert_not_called() - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() mock_get_model_specs.reset_mock() mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertTrue(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("ay", 
script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index a809b32a24..0d1f6eb2d1 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -30,13 +30,16 @@ HubType, HubContentType, ) - +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_MANIFEST, + BASE_PROPRIETARY_SPEC, BASE_HEADER, + BASE_PROPRIETARY_HEADER, SPECIAL_MODEL_SPECS_DICT, ) @@ -47,11 +50,16 @@ def get_header_from_base_header( model_id: str = None, semantic_version_str: str = None, version: str = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: if version and semantic_version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_HEADER) + return JumpStartModelHeader(spec) + if all( [ "pytorch" not in model_id, @@ -82,7 +90,10 @@ def get_header_from_base_header( def get_prototype_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: + if model_type == JumpStartModelType.PROPRIETARY: + return [JumpStartModelHeader(spec) for spec in BASE_PROPRIETARY_MANIFEST] return [ get_header_from_base_header(region=region, model_id=model_id, version=version) for model_id in PROTOTYPICAL_MODEL_SPECS_DICT.keys() @@ -96,6 +107,7 @@ def get_prototype_model_spec( version: str = None, hub_arn: Optional[str] = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -112,6 +124,7 @@ def get_special_model_spec( version: str = None, hub_arn: Optional[str] = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -128,6 +141,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( version: str = None, hub_arn: Optional[str] = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. 
For this mock, we only retrieve model specs based on the model ID and adding @@ -148,16 +162,24 @@ def get_spec_from_base_spec( _obj: JumpStartModelsCache = None, region: str = None, model_id: str = None, - semantic_version_str: str = None, + version_str: str = None, version: str = None, hub_arn: Optional[str] = None, hub_model_arn: Optional[str] = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: - if version and semantic_version_str: + if version and version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_SPEC) + spec["version"] = version or version_str + spec["model_id"] = model_id + + return JumpStartModelSpecs(spec) + if model_id is not None: if all( [ @@ -181,7 +203,7 @@ def get_spec_from_base_spec( spec = copy.deepcopy(BASE_SPEC) - spec["version"] = version or semantic_version_str + spec["version"] = version or version_str spec["model_id"] = model_id return JumpStartModelSpecs(spec) @@ -194,11 +216,14 @@ def patched_retrieval_function( ) -> JumpStartCachedContentValue: datatype, id_info = key.data_type, key.id_info - if datatype == JumpStartS3FileType.MANIFEST: - return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) + if datatype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + + return JumpStartCachedContentValue( + formatted_content=get_formatted_manifest(BASE_MANIFEST) + ) - if datatype == JumpStartS3FileType.SPECS: - _, model_id, specs_version = id_info.split("/") + if datatype == JumpStartCachedContentValue.OPEN_WEIGHT_SPECS: + _, model_id, specs_version = s3_key.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) @@ -214,7 +239,23 @@ def patched_retrieval_function( if datatype == HubType.HUB: return None - raise ValueError(f"Bad value for filetype: {datatype}") + if datatype == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedContentValue( + formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) + ) + + if datatype == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = s3_key.split("/") + version = specs_version.replace("proprietary_specs_", "").replace(".json", "") + return JumpStartCachedContentValue( + formatted_content=get_spec_from_base_spec( + model_id=model_id, + version=version, + model_type=JumpStartModelType.PROPRIETARY, + ) + ) + + raise ValueError(f"Bad value for datatype: {datatype}") def overwrite_dictionary( diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index afba23b5a4..593400ea22 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import metric_definitions +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -25,10 +26,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_metric_definitions(patched_get_model_specs): 
+def test_jumpstart_default_metric_definitions( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,7 +52,12 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -63,7 +73,12 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index cde3258133..2bb327c26f 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import model_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.model_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_model_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -48,6 +53,7 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -66,6 +72,7 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -85,6 +92,7 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -104,6 +112,7 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py 
b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 1ad25f962f..2a4d913a75 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -18,6 +18,7 @@ import pytest from sagemaker import resource_requirements +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.jumpstart.artifacts.resource_requirements import ( REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP, @@ -26,10 +27,14 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_resource_requirements(patched_get_model_specs): +def test_jumpstart_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS region = "us-west-2" mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -46,7 +51,12 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): assert default_inference_resource_requirements.requests["memory"] == 34360 patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -100,9 +110,14 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode } +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): +def test_jumpstart_no_supported_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): + patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" @@ -119,7 +134,12 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): assert default_inference_resource_requirements is None patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 5811a9d822..87364a16fc 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import script_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from 
tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.script_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_script_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -48,6 +53,7 @@ def test_jumpstart_common_script_uri( version="*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -66,6 +72,7 @@ def test_jumpstart_common_script_uri( version="1.*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -85,6 +92,7 @@ def test_jumpstart_common_script_uri( version="*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -104,6 +112,7 @@ def test_jumpstart_common_script_uri( version="1.*", s3_client=mock_client, hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 94354e782e..73030b508c 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -19,18 +19,23 @@ from sagemaker import base_serializers, serializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_serializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -46,20 +51,29 @@ def test_jumpstart_default_serializers( assert isinstance(default_serializer, base_serializers.IdentitySerializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + 
s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -82,5 +96,10 @@ def test_jumpstart_serializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) From 95780d344d8cbee2a6f65fdf3b102e4832c42ad6 Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Tue, 12 Mar 2024 20:12:10 -0400 Subject: [PATCH 31/42] chore: emit warning when no instance specific gated training env var is available, and raise exception when accept_eula flag is not supplied (#4485) * fix: raise exception when no instance specific gated training env var available * chore: raise client exception if accept_eula flag is not set for gated models * chore: address flake8 errors * chore: emit warning when instance type is chosen with no gated training artifacts --- .../artifacts/environment_variables.py | 34 +- src/sagemaker/jumpstart/factory/estimator.py | 21 + src/sagemaker/jumpstart/types.py | 4 + src/sagemaker/jumpstart/utils.py | 22 +- .../jumpstart/test_default.py | 65 + tests/unit/sagemaker/jumpstart/constants.py | 1236 +++++++++++++++++ .../jumpstart/estimator/test_estimator.py | 46 +- tests/unit/sagemaker/jumpstart/test_utils.py | 10 +- 8 files changed, 1388 insertions(+), 50 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index b85cfe4572..a800d06e3d 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -12,10 +12,11 @@ # language governing permissions and limitations under the License. 
"""This module contains functions for obtaining JumpStart environment variables.""" from __future__ import absolute_import -from typing import Dict, Optional +from typing import Callable, Dict, Optional, Set from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, + JUMPSTART_LOGGER, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) from sagemaker.jumpstart.enums import ( @@ -114,7 +115,9 @@ def _retrieve_default_environment_variables( default_environment_variables.update(instance_specific_environment_variables) - gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value( + retrieve_gated_env_var_for_instance_type: Callable[ + [str], Optional[str] + ] = lambda instance_type: _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, hub_arn=hub_arn, @@ -125,6 +128,33 @@ def _retrieve_default_environment_variables( instance_type=instance_type, ) + gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type( + instance_type + ) + + if gated_model_env_var is None and model_specs.is_gated_model(): + + possible_env_vars: Set[str] = { + retrieve_gated_env_var_for_instance_type(instance_type) + for instance_type in model_specs.supported_training_instance_types + } + + # If all officially supported instance types have the same underlying artifact, + # we can use this artifact with high confidence that it'll succeed with + # an arbitrary instance. + if len(possible_env_vars) == 1: + gated_model_env_var = list(possible_env_vars)[0] + + # If this model does not have 1 artifact for all supported instance types, + # we cannot determine which artifact to use for an arbitrary instance. + else: + log_msg = ( + f"'{model_id}' does not support {instance_type} instance type" + " for training. Please use one of the following instance types: " + f"{', '.join(model_specs.supported_training_instance_types)}." + ) + JUMPSTART_LOGGER.warning(log_msg) + if gated_model_env_var is not None: default_environment_variables.update( {SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var} diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 0074374e1a..fb598256fa 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -63,6 +63,7 @@ from sagemaker.jumpstart.utils import ( add_hub_arn_tags, add_jumpstart_model_id_version_tags, + get_eula_message, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, verify_model_region_and_return_specs, @@ -617,6 +618,26 @@ def _add_env_to_kwargs( value, ) + environment = getattr(kwargs, "environment", {}) or {} + if ( + environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY) + and str(environment.get("accept_eula", "")).lower() != "true" + ): + model_specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + region=kwargs.region, + scope=JumpStartScriptScope.TRAINING, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, + ) + if model_specs.is_gated_model(): + raise ValueError( + "Need to define ‘accept_eula'='true' within Environment. 
" + f"{get_eula_message(model_specs, kwargs.region)}" + ) + return kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 4dae2383d1..b70d93c67d 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -991,6 +991,10 @@ def use_training_model_artifact(self) -> bool: # otherwise, return true is a training model package is not set return len(self.training_model_package_artifact_uris or {}) == 0 + def is_gated_model(self) -> bool: + """Returns True if the model has a EULA key or the model bucket is gated.""" + return self.gated_bucket or self.hosting_eula_key is not None + def supports_incremental_training(self) -> bool: """Returns True if the model supports incremental training.""" return self.incremental_training_supported diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f94f87281d..5ba6a3b98b 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -491,21 +491,25 @@ def update_inference_tags_with_jumpstart_training_tags( return inference_tags +def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: + """Returns EULA message to display if one is available, else empty string.""" + if model_specs.hosting_eula_key is None: + return "" + return ( + f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). " + f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." + f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}" + f"/{model_specs.hosting_eula_key} for terms of use." + ) + + def emit_logs_based_on_model_specs( model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client ) -> None: """Emits logs based on model specs and region.""" if model_specs.hosting_eula_key: - constants.JUMPSTART_LOGGER.info( - "Model '%s' requires accepting end-user license agreement (EULA). 
" - "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", - model_specs.model_id, - get_jumpstart_content_bucket(region=region), - region, - ".cn" if region.startswith("cn-") else "", - model_specs.hosting_eula_key, - ) + constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region)) full_version: str = model_specs.version diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index d1abc2bb68..8971b03aff 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import environment_variables +from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -207,6 +208,70 @@ def test_jumpstart_sdk_environment_variables( ) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_sdk_environment_variables_1_artifact_all_variants(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "gemma-model-1-artifact" + region = "us-west-2" + + assert { + "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/" + "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz" + } == environment_variables.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, + instance_type="ml.p3.2xlarge", + script="training", + ) + + +@patch("sagemaker.jumpstart.artifacts.environment_variables.JUMPSTART_LOGGER") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_sdk_environment_variables_no_gated_env_var_available( + patched_get_model_specs, patched_jumpstart_logger +): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "gemma-model" + region = "us-west-2" + + assert {} == environment_variables.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, + instance_type="ml.p3.2xlarge", + script="training", + ) + + patched_jumpstart_logger.warning.assert_called_once_with( + "'gemma-model' does not support ml.p3.2xlarge instance type for " + "training. Please use one of the following instance types: " + "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge." 
+ ) + + # assert that supported instance types succeed + assert { + "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/" + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz" + } == environment_variables.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, + instance_type="ml.g5.24xlarge", + script="training", + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs): diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index cb0f4831b7..ef2c4c30a4 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -14,6 +14,1242 @@ SPECIAL_MODEL_SPECS_DICT = { + "gemma-model": { + "model_id": "huggingface-llm-gemma-7b-instruct", + "url": "https://huggingface.co/google/gemma-7b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", + }, + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-instruct/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-i" + "nstruct/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "peft_type", + "type": "text", + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + 
{"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, + { + "name": "double_quant", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "quant_type", + "type": "text", + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_from_scratch", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "fp16", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "bf16", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "evaluation_strategy", + "type": "text", + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", + }, + { + "name": "eval_steps", + "type": "int", + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 2048, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "gradient_checkpointing", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + 
"min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + 
"type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-7b-instruct", + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/h" + "uggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + 
"us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:" + "2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingf" + "ace-llm-gemma-7b-instruct.tar.gz" + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/" + "p4d/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 98304, "num_accelerators": 4}, + "dynamic_container_deployment_supported": True, + }, + }, + "gemma-model-1-artifact": { + "model_id": "huggingface-llm-gemma-7b-instruct", + "url": "https://huggingface.co/google/gemma-7b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", + }, + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-instruct/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-llm/huggingface-llm-gemma-7b-i" + "nstruct/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "peft_type", + "type": "text", + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", 
"False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, + { + "name": "double_quant", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "quant_type", + "type": "text", + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_from_scratch", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "fp16", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "bf16", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "evaluation_strategy", + "type": "text", + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", + }, + { + "name": "eval_steps", + "type": "int", + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 2048, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "gradient_checkpointing", + "type": "text", + "default": 
"False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": 
"text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-7b-instruct", + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/h" + "uggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": 
{"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:" + "2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/train-hugg" + "ingface-llm-gemma-7b-instruct.tar.gz" + }, + }, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 98304, "num_accelerators": 4}, + "dynamic_container_deployment_supported": True, + }, + }, "env-var-variant-model": { "model_id": "huggingface-llm-falcon-180b-bf16", "url": "https://huggingface.co/tiiuae/falcon-180B", diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index f381573fe8..e0e3348cc2 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -460,39 +460,21 @@ def test_gated_model_s3_uri( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - JumpStartEstimator( - model_id=model_id, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - ) - - mock_estimator_init.assert_called_once_with( - instance_type="ml.p3.2xlarge", - instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117", - source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" - "meta/transfer_learning/textgeneration/v1.0.0/sourcedir.tar.gz", - entry_point="transfer_learning.py", - role=execution_role, - 
sagemaker_session=sagemaker_session, - max_run=360000, - enable_network_isolation=True, - encrypt_inter_container_traffic=True, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - tags=[ - { - "Key": "sagemaker-sdk:jumpstart-model-id", - "Value": "js-gated-artifact-trainable-model", + with pytest.raises(ValueError) as e: + JumpStartEstimator( + model_id=model_id, + environment={ + "accept_eula": "false", + "what am i": "doing", + "SageMakerGatedModelS3Uri": "none of your business", }, - {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.0"}, - ], + ) + assert str(e.value) == ( + "Need to define ‘accept_eula'='true' within Environment. " + "Model 'meta-textgeneration-llama-2-7b-f' requires accepting end-user " + "license agreement (EULA). See " + "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llamaEula.txt" + " for terms of use." ) mock_estimator_init.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 3ec5ba30ec..fa1a3fc72f 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -946,13 +946,9 @@ def make_accept_eula_inference_spec(*largs, **kwargs): make_accept_eula_inference_spec(), "us-east-1", MOCK_CLIENT ) mocked_info_log.assert_any_call( - "Model '%s' requires accepting end-user license agreement (EULA). " - "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", - "pytorch-eqa-bert-base-cased", - "jumpstart-cache-prod-us-east-1", - "us-east-1", - "", - "read/the/fine/print.txt", + "Model 'pytorch-eqa-bert-base-cased' requires accepting end-user license agreement (EULA). " + "See https://jumpstart-cache-prod-us-east-1.s3.us-east-1.amazonaws.com/read/the/fine/print.txt" + " for terms of use.", ) From 2d638ebee60f9bd861d56d891bd3b3ff86b2f306 Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Fri, 16 Feb 2024 12:32:39 -0500 Subject: [PATCH 32/42] change: bump jinja2 to 3.1.3 in doc/requirments.txt (#4421) (#4423) * change: bump jinja2 to 3.1.3 in doc/requirments.txt (#4421) * change: bump jinja2 to 3.1.3 in doc/requirments.txt * Update requirements.txt * feature: TGI 1.4.0 (#4424) * documentation: fix the ClarifyCheckStep documentation to mention PDP (#4259) * documentation: fix the ClarifyCheckStep documentation to mention PDP support * fix: break the lines to meet pylint requirement --------- Co-authored-by: Shing Lyu * documentation: Explain the ClarifyCheckStep and QualityCheckStep parameters (#4261) * documentation: explain the ClarifyCheckStep and QualityCheckStep parameters * fix: remove trailing space --------- Co-authored-by: Shing Lyu * feat: Telemetry metrics (#4414) * Emit additional telemetry metrics * Fix unit tests * Emit endpoint failure to telemetry * Address PR Comments * Emit latency in telemetry * Address PR Comments * Addressed PR Comments * Address PR Comments * Fix tests * Fix integ tests --------- Co-authored-by: Jonathan Makunga Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * documentation: change order of pipelines topics (#4427) * prepare release v2.208.0 * update development version to v2.208.1.dev0 * feature: AutoGluon 1.0.0 image_uris update (#4426) --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Co-authored-by: Jinyoung Lim Co-authored-by: Shing Lyu Co-authored-by: Shing Lyu 
Co-authored-by: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Co-authored-by: Jonathan Makunga Co-authored-by: stacicho Co-authored-by: ci Co-authored-by: tonyhu --- CHANGELOG.md | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 699a44c787..a550961614 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,51 +4,51 @@ ### Features - * Update SM Python SDK for PT 2.2.0 SM DLC +- Update SM Python SDK for PT 2.2.0 SM DLC ### Bug Fixes and Other Changes - * Create custom tarfile extractall util to fix backward compatibility issue - * Upgrade smp to version 2.2 - * Enhance model builder selection logic to include model size +- Create custom tarfile extractall util to fix backward compatibility issue +- Upgrade smp to version 2.2 +- Enhance model builder selection logic to include model size ## v2.211.0 (2024-03-05) ### Features - * pin dll version to support python3.11 to the sdk - * instance specific jumpstart host requirements - * Add TensorFlow 2.14 image configs - * Add AutoMLV2 support - * Support selective pipeline execution between function step and regular step - * Add new Triton DLC URIs +- pin dll version to support python3.11 to the sdk +- instance specific jumpstart host requirements +- Add TensorFlow 2.14 image configs +- Add AutoMLV2 support +- Support selective pipeline execution between function step and regular step +- Add new Triton DLC URIs ### Bug Fixes and Other Changes - * Skip No Canvas regions for test_deploy_best_candidate - * make sure gpus are found in local_gpu run - * Bump Apache Airflow version to 2.8.2 - * properly close sagemaker config file after loading config - * remove enable_network_isolation from the python doc +- Skip No Canvas regions for test_deploy_best_candidate +- make sure gpus are found in local_gpu run +- Bump Apache Airflow version to 2.8.2 +- properly close sagemaker config file after loading config +- remove enable_network_isolation from the python doc ### Documentation Changes - * Add doc for new feature processor APIs and classes +- Add doc for new feature processor APIs and classes ## v2.210.0 (2024-02-28) ### Features - * Prepend SageMaker Studio App Type to boto3 User Agent string - * TGI optimum 0.0.18 (general+llm) - * TGI 1.4.2 +- Prepend SageMaker Studio App Type to boto3 User Agent string +- TGI optimum 0.0.18 (general+llm) +- TGI 1.4.2 ### Bug Fixes and Other Changes - * tolerate vulnerable old model for integ test and temporarily skip test_list_jumpstart_models_script_filter - * add missing regions to pytorch config - * Add validation for sagemaker version on remote job - * fixed implementation of fail_on_violation for transform with monitoring +- tolerate vulnerable old model for integ test and temporarily skip test_list_jumpstart_models_script_filter +- add missing regions to pytorch config +- Add validation for sagemaker version on remote job +- fixed implementation of fail_on_violation for transform with monitoring ## v2.209.0 (2024-02-24) From 29718b4061c15742eb4f3bea07b8f3448f0dff41 Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 21 Feb 2024 10:58:01 -0500 Subject: [PATCH 33/42] feat: add hub and hubcontent support in retrieval function for jumpstart model cache (#4438) --- src/sagemaker/jumpstart/cache.py | 2 ++ src/sagemaker/jumpstart/constants.py | 3 +- src/sagemaker/jumpstart/types.py | 19 +++++++++++++ src/sagemaker/jumpstart/utils.py | 23 ++++++++++++++++ tests/unit/sagemaker/jumpstart/test_utils.py | 29 
++++++++++++++++++++ tests/unit/sagemaker/jumpstart/utils.py | 11 ++++---- 6 files changed, 80 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 63b47632fd..0d93ffb76b 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -57,6 +57,7 @@ DescribeHubContentsResponse, HubType, HubContentType, + HubDataType, ) from sagemaker.jumpstart.curated_hub import utils as hub_utils from sagemaker.jumpstart.enums import JumpStartModelType @@ -428,6 +429,7 @@ def _retrieval_function( """ data_type, id_info = key.data_type, key.id_info + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 1b679d44f6..21412d65ea 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -172,7 +172,8 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" -HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" +# works cross-partition +HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index b70d93c67d..82bfc5c13f 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -972,6 +972,25 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: """ # TODO: Implement + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartModelSpecs object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + def supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" return getattr(self, "hosting_prepacked_artifact_key", None) is not None diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 5ba6a3b98b..82b43e7d22 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -15,6 +15,7 @@ import logging import os from typing import Any, Dict, List, Set, Optional, Tuple, Union +import re from urllib.parse import urlparse import boto3 from packaging.version import Version @@ -866,3 +867,25 @@ def get_jumpstart_model_id_version_from_resource_arn( def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str: """Returns the Studio Spec file prefix given a model ID and version.""" return f"studio_models/{model_id}/studio_specs_v{model_version}.json" + +def extract_info_from_hub_content_arn( + arn: str, +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Extracts hub_name, content_name, and content_version from a HubContentArn""" + + match = re.match(constants.HUB_MODEL_ARN_REGEX, arn) + if match: + hub_name = match.group(4) + hub_region = match.group(2) + content_name = match.group(5) + 
content_version = match.group(6) + + return hub_name, hub_region, content_name, content_version + + match = re.match(constants.HUB_ARN_REGEX, arn) + if match: + hub_name = match.group(4) + hub_region = match.group(2) + return hub_name, hub_region, None, None + + return None, None, None, None diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fa1a3fc72f..472b2dfdd9 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1214,6 +1214,35 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type +def test_extract_info_from_hub_content_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.extract_info_from_hub_content_arn(model_arn) == ( + "MockHub", + "us-west-2", + "my-mock-model", + "1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "nonsense-string" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" + ) + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 0d1f6eb2d1..128e41a796 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,6 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( + HubDataType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -218,12 +219,10 @@ def patched_retrieval_function( datatype, id_info = key.data_type, key.id_info if datatype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - return JumpStartCachedContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) + return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) - if datatype == JumpStartCachedContentValue.OPEN_WEIGHT_SPECS: - _, model_id, specs_version = s3_key.split("/") + if datatype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) @@ -245,7 +244,7 @@ def patched_retrieval_function( ) if datatype == JumpStartS3FileType.PROPRIETARY_SPECS: - _, model_id, specs_version = s3_key.split("/") + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("proprietary_specs_", "").replace(".json", "") return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec( From d7c4307d52a1dd701b61bd60088e6c661aeceff0 Mon Sep 17 00:00:00 2001 From: Ben Crabtree 
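For quick reference, a self-contained sketch of the hub-content ARN parsing introduced in this patch. The two regex constants are copied from the diff rather than imported, and parse_hub_arn is a local stand-in for extract_info_from_hub_content_arn, so the snippet runs without the SDK installed; the fixtures mirror the new unit test.

import re
from typing import Optional, Tuple

# Copied from the constants added above (HUB_MODEL_ARN_REGEX / HUB_ARN_REGEX).
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"

def parse_hub_arn(arn: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Return (hub_name, hub_region, content_name, content_version), mirroring the helper above."""
    match = re.match(HUB_MODEL_ARN_REGEX, arn)
    if match:
        return match.group(4), match.group(2), match.group(5), match.group(6)
    match = re.match(HUB_ARN_REGEX, arn)
    if match:
        return match.group(4), match.group(2), None, None
    return None, None, None, None

# Same fixtures as the unit test above:
assert parse_hub_arn(
    "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
) == ("MockHub", "us-west-2", "my-mock-model", "1.0.2")
assert parse_hub_arn("arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub") == (
    "MockHub", "us-west-2", None, None
)
assert parse_hub_arn("nonsense-string") == (None, None, None, None)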
Date: Mon, 26 Feb 2024 14:18:18 -0500 Subject: [PATCH 34/42] feat: jsch jumpstart estimator support (#4439) --- src/sagemaker/jumpstart/accessors.py | 1 + src/sagemaker/jumpstart/cache.py | 1 + src/sagemaker/jumpstart/constants.py | 3 +- src/sagemaker/jumpstart/estimator.py | 1 + src/sagemaker/jumpstart/factory/estimator.py | 2 ++ src/sagemaker/jumpstart/factory/model.py | 2 ++ src/sagemaker/jumpstart/types.py | 8 +++++ src/sagemaker/jumpstart/utils.py | 2 +- .../jumpstart/test_validate.py | 2 ++ .../image_uris/jumpstart/test_common.py | 4 +++ .../jumpstart/test_instance_types.py | 2 +- .../jumpstart/curated_hub/test_utils.py | 3 +- .../sagemaker/jumpstart/test_accessors.py | 25 ++++++++++++++++ .../jumpstart/test_notebook_utils.py | 1 + tests/unit/sagemaker/jumpstart/test_utils.py | 29 ------------------- tests/unit/sagemaker/jumpstart/utils.py | 5 +++- .../model_uris/jumpstart/test_common.py | 4 +++ .../jumpstart/test_resource_requirements.py | 1 + .../script_uris/jumpstart/test_common.py | 4 +++ 19 files changed, 65 insertions(+), 35 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index dfc833ec28..c9f805c225 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -257,6 +257,7 @@ def get_model_specs( hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 0d93ffb76b..75317b6784 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -39,6 +39,7 @@ get_wildcard_model_version_msg, get_wildcard_proprietary_model_version_msg, ) +from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 21412d65ea..1b679d44f6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -172,8 +172,7 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" -# works cross-partition -HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 6406932924..4cdd540111 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -534,6 +534,7 @@ def _validate_model_id_and_get_type_hook(): model_version=model_version, hub_arn=hub_arn, model_type=self.model_type, + hub_arn=hub_arn, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index fb598256fa..fa04b46a7c 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -81,6 +81,7 @@ def get_init_kwargs( model_version: Optional[str] = None, 
hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -140,6 +141,7 @@ def get_init_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 6f7a83cef1..273257088e 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -550,6 +550,7 @@ def get_deploy_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -584,6 +585,7 @@ def get_deploy_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 82bfc5c13f..84fe68c0ef 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1420,6 +1420,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "initial_instance_count", "instance_type", "region", @@ -1453,6 +1454,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1499,6 +1501,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1535,6 +1538,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "instance_type", "instance_count", "region", @@ -1596,6 +1600,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", } def __init__( @@ -1725,6 +1730,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "inputs", "wait", @@ -1741,6 +1747,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1769,6 +1776,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.region = region self.inputs = inputs self.wait = wait diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 82b43e7d22..4fc8752625 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,8 +14,8 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Set, Optional, Tuple, Union import re +from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py 
b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0f69cb572a..93d7098870 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -453,6 +453,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -516,6 +517,7 @@ def test_jumpstart_validate_all_hyperparameters( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index bd4383499d..45af6faeed 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -56,6 +56,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,6 +79,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -100,6 +102,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -122,6 +125,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index ea6835bec3..05c4df28f9 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -123,9 +123,9 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode region=region, model_id=model_id, version=model_version, + model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 15b6d8fba3..0c2393b71a 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -14,8 +14,9 @@ from unittest.mock import Mock from sagemaker.jumpstart.types import HubArnExtractedInfo -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.curated_hub import utils +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo def test_get_info_from_hub_resource_arn(): diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index fb22287909..6ab71c3f53 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -137,6 +137,31 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache): > 0 ) 
+@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_models_cache_get_model_specs(mock_cache): + mock_cache.get_specs = Mock() + mock_cache.get_hub_model = Mock() + model_id, version = "pytorch-ic-mobilenet-v2", "*" + region = "us-west-2" + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=version + ) + mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version) + mock_cache.get_hub_model.assert_not_called() + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=version, + hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub", + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}" + ) + ) + # necessary because accessors is a static module reload(accessors) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index a5d1ee3ac2..ed7c870a0e 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -751,4 +751,5 @@ def test_get_model_url( s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 472b2dfdd9..fa1a3fc72f 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1214,35 +1214,6 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type -def test_extract_info_from_hub_content_arn(): - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" - ) - assert utils.extract_info_from_hub_content_arn(model_arn) == ( - "MockHub", - "us-west-2", - "my-mock-model", - "1.0.2", - ) - - hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) - - invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "nonsense-string" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" - ) - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 128e41a796..4e816b6b97 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,7 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubDataType, + HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -253,6 +253,9 @@ def patched_retrieval_function( 
model_type=JumpStartModelType.PROPRIETARY, ) ) + # TODO: Implement + if datatype == HubContentType.HUB: + return None raise ValueError(f"Bad value for datatype: {datatype}") diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 2bb327c26f..06587a2074 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -54,6 +54,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -73,6 +74,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -93,6 +95,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -113,6 +116,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 2a4d913a75..b2e055dd3c 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 87364a16fc..14ad48082e 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -73,6 +74,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -93,6 +95,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -113,6 +116,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() From 92e35c8cb817990cd67ee862ced2051bcb48444b Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 28 Feb 2024 17:09:01 -0500 Subject: [PATCH 35/42] Master jumpstart curated hub (#4464) --- src/sagemaker/image_uri_config/pytorch.json | 4277 ++++++++--------- .../bootstrap_runtime_environment.py | 3 + .../runtime_environment_manager.py | 2 + src/sagemaker/serve/schema/task.json | 130 +- src/sagemaker/utils.py | 1 + 5 files 
changed, 2143 insertions(+), 2270 deletions(-) diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 85b454ebed..f399754c00 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -1,2216 +1,2083 @@ { - "eia": { - "processors": [ - "cpu" - ], - "version_aliases": { - "1.3": "1.3.1", - "1.5": "1.5.1" + "eia": { + "processors": ["cpu"], + "version_aliases": { + "1.3": "1.3.1", + "1.5": "1.5.1" + }, + "versions": { + "1.3.1": { + "py_versions": ["py3"], + "registries": { + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-central-2": "380420809688", + "eu-west-1": "763104351884", + "eu-south-2": "503227376785", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" }, - "versions": { - "1.3.1": { - "py_versions": [ - "py3" - ], - "registries": { - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-north-1": "763104351884", - "eu-central-2": "380420809688", - "eu-west-1": "763104351884", - "eu-south-2": "503227376785", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference-eia" - }, - "1.5.1": { - "py_versions": [ - "py3" - ], - "registries": { - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "eu-central-2": "380420809688", - "eu-west-1": "763104351884", - "eu-south-2": "503227376785", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference-eia" - } - } + "repository": "pytorch-inference-eia" + }, + "1.5.1": { + "py_versions": ["py3"], + "registries": { + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "eu-central-2": "380420809688", + "eu-west-1": "763104351884", + "eu-south-2": "503227376785", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference-eia" + } + } + }, + "inference": { + "processors": ["cpu", "gpu"], + "version_aliases": { + "0.4": "0.4.0", + "1.0": "1.0.0", + "1.1": "1.1.0", + "1.2": "1.2.0", + "1.3": "1.3.1", + "1.4": "1.4.0", + "1.5": "1.5.0", + "1.6": "1.6.0", + "1.7": "1.7.1", + "1.8": "1.8.1", + "1.9": "1.9.1", + "1.10": "1.10.2", + "1.11": "1.11.0", + "1.12": "1.12.1", + "1.13": "1.13.1", + "2.0": 
"2.0.1", + "2.1": "2.1.0" }, - "inference": { - "processors": [ - "cpu", - "gpu" - ], - "version_aliases": { - "0.4": "0.4.0", - "1.0": "1.0.0", - "1.1": "1.1.0", - "1.2": "1.2.0", - "1.3": "1.3.1", - "1.4": "1.4.0", - "1.5": "1.5.0", - "1.6": "1.6.0", - "1.7": "1.7.1", - "1.8": "1.8.1", - "1.9": "1.9.1", - "1.10": "1.10.2", - "1.11": "1.11.0", - "1.12": "1.12.1", - "1.13": "1.13.1", - "2.0": "2.0.1", - "2.1": "2.1.0" + "versions": { + "0.4.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" }, - "versions": { - "0.4.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-pytorch" - }, - "1.0.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-pytorch" - }, - "1.1.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - 
"ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-pytorch" - }, - "1.2.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.3.1": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.4.0": { - "py_versions": [ - "py3" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": 
"763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.5.0": { - "py_versions": [ - "py3" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.6.0": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.7.1": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - 
"af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.8.0": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.8.1": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": 
"442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.9.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.9.1": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.10.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - 
"eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.10.2": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.11.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.12.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": 
"763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.12.1": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "1.13.1": { - "py_versions": [ - "py39" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "2.0.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": 
"763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "2.0.1": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "2.1.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - }, - "2.2.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": 
"780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference" - } - } + "repository": "sagemaker-pytorch" + }, + "1.0.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-pytorch" + }, + "1.1.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-pytorch" + }, + "1.2.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + 
"cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.3.1": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.4.0": { + "py_versions": ["py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.5.0": { + "py_versions": ["py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": 
"364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.6.0": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.7.1": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": 
"pytorch-inference" + }, + "1.8.0": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.8.1": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.9.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": 
"763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.9.1": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.10.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.10.2": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": 
"763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.11.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.12.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.12.1": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": 
"457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "1.13.1": { + "py_versions": ["py39"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "2.0.0": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "2.0.1": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + 
"ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "2.1.0": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + }, + "2.2.0": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference" + } + } + }, + "inference_graviton": { + "processors": ["cpu"], + "version_aliases": { + "1.12": "1.12.1", + "2.0": "2.0.1", + "2.1": "2.1.0" }, - 
"inference_graviton": { - "processors": [ - "cpu" - ], - "version_aliases": { - "1.12": "1.12.1", - "2.0": "2.0.1", - "2.1": "2.1.0" + "versions": { + "1.12.1": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" }, - "versions": { - "1.12.1": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } - }, - "2.0.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", 
- "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } - }, - "2.0.1": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } - }, - "2.1.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } - } + "repository": "pytorch-inference-graviton", + "container_version": { + "cpu": "ubuntu20.04" } - }, - "training": { - "processors": [ - "cpu", - "gpu" - ], - "version_aliases": { - "0.4": "0.4.0", - "1.0": "1.0.0", - "1.1": "1.1.0", - "1.2": "1.2.0", - "1.3": "1.3.1", - "1.4": "1.4.0", - "1.5": "1.5.0", - "1.6": "1.6.0", - "1.7": "1.7.1", - "1.8": "1.8.1", - "1.9": "1.9.1", - "1.10": "1.10.2", - "1.11": "1.11.0", - "1.12": "1.12.1", - "1.13": "1.13.1", - "2.0": "2.0.1", - "2.1": "2.1.0", - "2.2": "2.2.0" + }, + "2.0.0": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": 
"763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" }, - "versions": { - "0.4.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-pytorch" - }, - "1.0.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-pytorch" - }, - "1.1.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-pytorch" - }, - "1.2.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - 
"ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.3.1": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.4.0": { - "py_versions": [ - "py2", - "py3" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.5.0": { - "py_versions": [ - "py3" - ], - "registries": { - "af-south-1": 
"626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.6.0": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.7.1": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - 
"us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.8.0": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.8.1": { - "py_versions": [ - "py3", - "py36" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.9.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - 
"eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.9.1": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.10.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.10.2": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": 
"763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.11.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.12.0": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.12.1": { - "py_versions": [ - "py38" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": 
"763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "1.13.1": { - "py_versions": [ - "py39" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "2.0.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "2.0.1": { - "py_versions": [ - "py310" - ], - 
"registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "2.1.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - }, - "2.2.0": { - "py_versions": [ - "py310" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": 
"763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "pytorch-training" - } + "repository": "pytorch-inference-graviton", + "container_version": { + "cpu": "ubuntu20.04" } + }, + "2.0.1": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference-graviton", + "container_version": { + "cpu": "ubuntu20.04" + } + }, + "2.1.0": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-inference-graviton", + "container_version": { + "cpu": "ubuntu20.04" + } + } + } + }, + "training": { + "processors": ["cpu", "gpu"], + "version_aliases": { + "0.4": "0.4.0", + "1.0": "1.0.0", + "1.1": "1.1.0", + "1.2": "1.2.0", + "1.3": "1.3.1", + "1.4": "1.4.0", + "1.5": "1.5.0", + "1.6": "1.6.0", + "1.7": "1.7.1", + "1.8": "1.8.1", + "1.9": "1.9.1", + "1.10": "1.10.2", + "1.11": "1.11.0", + "1.12": "1.12.1", + "1.13": "1.13.1", + "2.0": "2.0.1", + "2.1": "2.1.0", + "2.2": "2.2.0" + }, + "versions": { + "0.4.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + 
"us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-pytorch" + }, + "1.0.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-pytorch" + }, + "1.1.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-pytorch" + }, + "1.2.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.3.1": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + 
"ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.4.0": { + "py_versions": ["py2", "py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.5.0": { + "py_versions": ["py3"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": 
"204538143572" + }, + "repository": "pytorch-training" + }, + "1.6.0": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.7.1": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.8.0": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + 
"us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.8.1": { + "py_versions": ["py3", "py36"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.9.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.9.1": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + 
"eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.10.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.10.2": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.11.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": 
"907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.12.0": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.12.1": { + "py_versions": ["py38"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "1.13.1": { + "py_versions": ["py39"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + 
"ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "2.0.0": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "2.0.1": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "2.1.0": { + "py_versions": ["py310"], + 
"registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + }, + "2.2.0": { + "py_versions": ["py310"], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "pytorch-training" + } } + } } diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 8fd83bfcfe..5332f7bdd0 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -65,6 +65,9 @@ def main(sys_args=None): conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) + RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( + client_sagemaker_pysdk_version + ) user = getpass.getuser() if user != "root": diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py index 13493c1d15..64e6c087f8 100644 --- a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py +++ 
b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -24,6 +24,8 @@ import dataclasses import json +import sagemaker + class _UTCFormatter(logging.Formatter): """Class that overrides the default local time provider in log formatter.""" diff --git a/src/sagemaker/serve/schema/task.json b/src/sagemaker/serve/schema/task.json index c897f4abec..ef142b1b37 100644 --- a/src/sagemaker/serve/schema/task.json +++ b/src/sagemaker/serve/schema/task.json @@ -1,67 +1,67 @@ { - "fill-mask": { - "sample_inputs": { - "properties": { - "inputs": "Paris is the [MASK] of France.", - "parameters": {} - } - }, - "sample_outputs": { - "properties": [ - { - "sequence": "Paris is the capital of France.", - "score": 0.7 - } - ] - } - }, - "question-answering": { - "sample_inputs": { - "properties": { - "context": "I have a German Shepherd dog, named Coco.", - "question": "What is my dog's breed?" - } - }, - "sample_outputs": { - "properties": [ - { - "answer": "German Shepherd", - "score": 0.972, - "start": 9, - "end": 24 - } - ] - } - }, - "text-classification": { - "sample_inputs": { - "properties": { - "inputs": "Where is the capital of France?, Paris is the capital of France.", - "parameters": {} - } - }, - "sample_outputs": { - "properties": [ - { - "label": "entailment", - "score": 0.997 - } - ] - } - }, - "text-generation": { - "sample_inputs": { - "properties": { - "inputs": "Hello, I'm a language model", - "parameters": {} - } - }, - "sample_outputs": { - "properties": [ - { - "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" - } - ] - } - } + "fill-mask": { + "sample_inputs": { + "properties": { + "inputs": "Paris is the [MASK] of France.", + "parameters": {} + } + }, + "sample_outputs": { + "properties": [ + { + "sequence": "Paris is the capital of France.", + "score": 0.7 + } + ] + } + }, + "question-answering": { + "sample_inputs": { + "properties": { + "context": "I have a German Shepherd dog, named Coco.", + "question": "What is my dog's breed?" + } + }, + "sample_outputs": { + "properties": [ + { + "answer": "German Shepherd", + "score": 0.972, + "start": 9, + "end": 24 + } + ] + } + }, + "text-classification": { + "sample_inputs": { + "properties": { + "inputs": "Where is the capital of France?, Paris is the capital of France.", + "parameters": {} + } + }, + "sample_outputs": { + "properties": [ + { + "label": "entailment", + "score": 0.997 + } + ] + } + }, + "text-generation": { + "sample_inputs": { + "properties": { + "inputs": "Hello, I'm a language model", + "parameters": {} + } + }, + "sample_outputs": { + "properties": [ + { + "generated_text": "Hello, I'm a language modeler. 
So while writing this, when I went out to meet my wife or come home she told me that my" + } + ] + } + } } diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 7896aac150..fe8c0b7c56 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -22,6 +22,7 @@ import random import re import shutil +import sys import tarfile import tempfile import time From d2f72a2816ca7532f18957c5c74f351e78f8e471 Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 28 Feb 2024 17:15:59 -0500 Subject: [PATCH 36/42] add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor (#4463) --- src/sagemaker/jumpstart/factory/model.py | 2 ++ src/sagemaker/jumpstart/types.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 273257088e..2fa538f33b 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -726,6 +726,7 @@ def get_init_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -759,6 +760,7 @@ def get_init_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, instance_type=instance_type, region=region, image_uri=image_uri, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 84fe68c0ef..e2fd0e280c 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1310,6 +1310,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1342,6 +1343,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1355,6 +1357,7 @@ def __init__( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1386,6 +1389,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.instance_type = instance_type self.region = region self.image_uri = image_uri From 52eae82fda78ae561123445087f553e24908f726 Mon Sep 17 00:00:00 2001 From: Jinyoung Lim Date: Thu, 29 Feb 2024 08:42:47 -0800 Subject: [PATCH 37/42] feature: JumpStart CuratedHub class creation and function definitions (#4448) --- src/sagemaker/jumpstart/cache.py | 2 +- src/sagemaker/jumpstart/types.py | 19 ----------- .../jumpstart/curated_hub/test_utils.py | 32 +++++++++++++++++++ tests/unit/sagemaker/jumpstart/test_cache.py | 3 +- tests/unit/sagemaker/jumpstart/utils.py | 4 +-- 5 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 75317b6784..11e8dc1792 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -34,6 +34,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, MODEL_TYPE_TO_MANIFEST_MAP, MODEL_TYPE_TO_SPECS_MAP, + 
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.exceptions import ( get_wildcard_model_version_msg, @@ -477,7 +478,6 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubType.HUB: hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e2fd0e280c..4b3b03283d 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -972,25 +972,6 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: """ # TODO: Implement - def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartModelSpecs object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - def supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" return getattr(self, "hosting_prepacked_artifact_key", None) is not None diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 0c2393b71a..d1ccd7d1c7 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -140,6 +140,38 @@ def test_generate_hub_arn_for_init_kwargs(): utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) + == hub_arn + ) + + +def test_generate_default_hub_bucket_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-east-1" + + assert ( + utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session) + == "sagemaker-hubs-us-east-1-123456789123" + ) + + +def test_create_hub_bucket_if_it_does_not_exist(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 348e93e7b7..5fac8a5c8d 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -28,6 +28,7 @@ JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, JumpStartModelsCache, ) +from 
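(Aside, not part of the diff: a minimal sketch of the two hub-bucket helpers these new tests exercise, inferred from the assertions above; the actual implementations in sagemaker.jumpstart.curated_hub.utils may differ.)

    def generate_default_hub_bucket_name(sagemaker_session):
        # Matches the name asserted in the test: sagemaker-hubs-<region>-<account>.
        region = sagemaker_session.boto_region_name
        account_id = sagemaker_session.account_id()
        return f"sagemaker-hubs-{region}-{account_id}"

    def create_hub_bucket_if_it_does_not_exist(bucket_name=None, sagemaker_session=None):
        region = sagemaker_session.boto_region_name
        bucket_name = bucket_name or generate_default_hub_bucket_name(sagemaker_session)
        bucket = sagemaker_session.boto_session.resource("s3").Bucket(bucket_name)
        if bucket.creation_date is None:  # the bucket does not exist yet
            if region == "us-east-1":
                bucket.create()  # us-east-1 rejects an explicit LocationConstraint
            else:
                bucket.create(CreateBucketConfiguration={"LocationConstraint": region})
        return bucket_name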
sagemaker.session_settings import SessionSettings from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - mocked_open.assert_not_called() + assert mocked_open.call_count == 2 mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 4e816b6b97..dc96636e9b 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,7 +22,6 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -32,6 +31,7 @@ HubContentType, ) from sagemaker.jumpstart.enums import JumpStartModelType + from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, @@ -254,7 +254,7 @@ def patched_retrieval_function( ) ) # TODO: Implement - if datatype == HubContentType.HUB: + if datatype == HubType.HUB: return None raise ValueError(f"Bad value for datatype: {datatype}") From c4501684b30062e3a0d2770df2be6cd8696d0c50 Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Tue, 12 Mar 2024 13:15:58 -0400 Subject: [PATCH 38/42] MultiPartCopy with Sync Algorithm (#4475) * first pass at sync function with util classes * adding tests and update clases * linting * file generator class inheritance * lint * multipart copy and algorithm updates * modularize sync * reformatting folders * testing for sync * do not tolerate vulnerable * remove prints * handle multithreading progress bar * update tests * optimize function and add hub bucket prefix * docstrings and linting --- src/sagemaker/jumpstart/cache.py | 1 - .../sagemaker/jumpstart/curated_hub/test_utils.py | 11 +++++++---- tests/unit/sagemaker/jumpstart/test_cache.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 11e8dc1792..7879333437 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -486,7 +486,6 @@ def _retrieval_function( formatted_content=DescribeHubResponse(hub_description) ) - raise ValueError( self._file_type_error_msg(data_type) ) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index d1ccd7d1c7..2e7ec017bf 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -140,10 +140,7 @@ def test_generate_hub_arn_for_init_kwargs(): utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) - assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) - == hub_arn - ) + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn def test_generate_default_hub_bucket_name(): @@ -163,8 +160,14 @@ def test_create_hub_bucket_if_it_does_not_exist(): mock_sagemaker_session.client("sts").get_caller_identity.return_value = { "Account": "123456789123" } + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # 
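(Aside, not part of the diff: a rough illustration of the multi-part S3 copy the "MultiPartCopy with Sync Algorithm" patch above introduces. The SDK's version fans the parts out across threads and drives a progress bar; the sequential loop and the bucket/key names here are placeholders for clarity only.)

    import boto3

    def multipart_copy(src_bucket, src_key, dst_bucket, dst_key, size_bytes,
                       part_size=500 * 1024 * 1024):
        s3 = boto3.client("s3")
        upload = s3.create_multipart_upload(Bucket=dst_bucket, Key=dst_key)
        parts = []
        for part_number, start in enumerate(range(0, size_bytes, part_size), start=1):
            end = min(start + part_size, size_bytes) - 1  # CopySourceRange is inclusive
            response = s3.upload_part_copy(
                Bucket=dst_bucket,
                Key=dst_key,
                PartNumber=part_number,
                UploadId=upload["UploadId"],
                CopySource={"Bucket": src_bucket, "Key": src_key},
                CopySourceRange=f"bytes={start}-{end}",
            )
            parts.append({"ETag": response["CopyPartResult"]["ETag"], "PartNumber": part_number})
        s3.complete_multipart_upload(
            Bucket=dst_bucket,
            Key=dst_key,
            UploadId=upload["UploadId"],
            MultipartUpload={"Parts": parts},
        )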
Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( sagemaker_session=mock_sagemaker_session diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 5fac8a5c8d..a9c34954ac 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -1134,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - assert mocked_open.call_count == 2 + mocked_open.assert_not_called() mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), From da1b6427363238a54431cb0d42e0297fdef35d40 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 13 Mar 2024 16:06:53 +0000 Subject: [PATCH 39/42] rebase with master --- src/sagemaker/jumpstart/cache.py | 6 ++++-- src/sagemaker/jumpstart/types.py | 1 - .../runtime_environment/runtime_environment_manager.py | 2 -- src/sagemaker/utils.py | 1 - .../instance_types/jumpstart/test_instance_types.py | 4 ++++ tests/unit/sagemaker/jumpstart/estimator/test_estimator.py | 6 +++--- tests/unit/sagemaker/jumpstart/test_accessors.py | 5 ++++- tests/unit/sagemaker/jumpstart/test_cache.py | 5 ++--- tests/unit/sagemaker/jumpstart/utils.py | 7 +++++++ 9 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 7879333437..d831d3023b 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -40,7 +40,6 @@ get_wildcard_model_version_msg, get_wildcard_proprietary_model_version_msg, ) -from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, @@ -453,7 +452,9 @@ def _retrieval_function( formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedContentValue(formatted_content=model_specs) + return JumpStartCachedContentValue( + formatted_content=model_specs + ) if data_type == HubContentType.MODEL: hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( @@ -478,6 +479,7 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) + if data_type == HubType.HUB: hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 4b3b03283d..485dc0b6b9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,7 +15,6 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union -from sagemaker.session import Session from sagemaker.utils import get_instance_type_family, format_tags, Tags from 
sagemaker.enums import EndpointType from sagemaker.model_metrics import ModelMetrics diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py index 64e6c087f8..13493c1d15 100644 --- a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py +++ b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -24,8 +24,6 @@ import dataclasses import json -import sagemaker - class _UTCFormatter(logging.Formatter): """Class that overrides the default local time provider in log formatter.""" diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index fe8c0b7c56..7896aac150 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -22,7 +22,6 @@ import random import re import shutil -import sys import tarfile import tempfile import time diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 05c4df28f9..f3454ca322 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -206,6 +206,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -227,6 +228,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -253,6 +255,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -282,6 +285,7 @@ def test_jumpstart_instance_types_from_hub(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index e0e3348cc2..e0ecce3398 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -292,7 +292,7 @@ def test_prepacked( @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_deploy_kwargs") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_fit_kwargs") @mock.patch("sagemaker.jumpstart.curated_hub.utils.construct_hub_arn_from_name") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -309,7 +309,7 @@ def test_hub_model( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_construct_hub_arn_from_name: mock.Mock, mock_retrieve_estimator_fit_kwargs: mock.Mock, mock_retrieve_model_deploy_kwargs: mock.Mock, @@ -321,7 +321,7 @@ def test_hub_model( 
         mock_get_caller_identity.return_value = "123456789123"
         mock_estimator_deploy.return_value = default_predictor
-        mock_is_valid_model_id.return_value = True
+        mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS

         model_id, _ = "pytorch-hub-model-1", "*"
         hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub"
diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py
index 6ab71c3f53..5d527dd5a1 100644
--- a/tests/unit/sagemaker/jumpstart/test_accessors.py
+++ b/tests/unit/sagemaker/jumpstart/test_accessors.py
@@ -137,6 +137,7 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
         > 0
     )

+
 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
 def test_jumpstart_models_cache_get_model_specs(mock_cache):
     mock_cache.get_specs = Mock()
@@ -147,7 +148,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
     accessors.JumpStartModelsAccessor.get_model_specs(
         region=region, model_id=model_id, version=version
     )
-    mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
+    mock_cache.get_specs.assert_called_once_with(
+        model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
+    )
     mock_cache.get_hub_model.assert_not_called()

     accessors.JumpStartModelsAccessor.get_model_specs(
diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py
index a9c34954ac..d5537712a0 100644
--- a/tests/unit/sagemaker/jumpstart/test_cache.py
+++ b/tests/unit/sagemaker/jumpstart/test_cache.py
@@ -28,7 +28,6 @@
     JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
     JumpStartModelsCache,
 )
-from sagemaker.session_settings import SessionSettings
 from sagemaker.jumpstart.constants import (
     ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
     ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -559,8 +558,8 @@ def test_jumpstart_proprietary_cache_accepts_input_parameters():
     )
     assert cache.get_region() == region
     assert cache.get_bucket() == bucket
-    assert cache._s3_cache._max_cache_items == max_s3_cache_items
-    assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon
+    assert cache._content_cache._max_cache_items == max_s3_cache_items
+    assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon
     assert (
         cache._proprietary_model_id_manifest_key_cache._max_cache_items
         == max_semantic_version_cache_items
diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py
index dc96636e9b..ad093640b7 100644
--- a/tests/unit/sagemaker/jumpstart/utils.py
+++ b/tests/unit/sagemaker/jumpstart/utils.py
@@ -253,6 +253,13 @@ def patched_retrieval_function(
                 model_type=JumpStartModelType.PROPRIETARY,
             )
         )
+
+    if datatype == HubContentType.MODEL:
+        _, _, _, model_name, model_version = id_info.split("/")
+        return JumpStartCachedContentValue(
+            formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
+        )
+
     # TODO: Implement
     if datatype == HubType.HUB:
         return None

From 709bedc282ea3df50288b3fac3b1df2ed513ccf6 Mon Sep 17 00:00:00 2001
From: Benjamin Crabtree
Date: Wed, 13 Mar 2024 20:43:29 +0000
Subject: [PATCH 40/42] bad rebase

---
 src/sagemaker/jumpstart/accessors.py                 |  1 -
 src/sagemaker/jumpstart/cache.py                     |  1 -
 src/sagemaker/jumpstart/estimator.py                 |  1 -
 src/sagemaker/jumpstart/factory/estimator.py         |  2 --
 src/sagemaker/jumpstart/factory/model.py             |  4 ----
 src/sagemaker/jumpstart/types.py                     | 13 +------------
 .../hyperparameters/jumpstart/test_validate.py       |  2 --
 .../sagemaker/image_uris/jumpstart/test_common.py    |  4 ----
 .../sagemaker/jumpstart/curated_hub/test_utils.py    |  1 -
 .../unit/sagemaker/jumpstart/test_notebook_utils.py  |  1 -
 .../sagemaker/model_uris/jumpstart/test_common.py    |  4 ----
 .../jumpstart/test_resource_requirements.py          |  1 -
 .../sagemaker/script_uris/jumpstart/test_common.py   |  4 ----
 13 files changed, 1 insertion(+), 38 deletions(-)

diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py
index c9f805c225..dfc833ec28 100644
--- a/src/sagemaker/jumpstart/accessors.py
+++ b/src/sagemaker/jumpstart/accessors.py
@@ -257,7 +257,6 @@ def get_model_specs(
         hub_arn: Optional[str] = None,
         s3_client: Optional[boto3.client] = None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn: Optional[str] = None,
     ) -> JumpStartModelSpecs:
         """Returns model specs from JumpStart models cache.

diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py
index d831d3023b..8d0f1832bf 100644
--- a/src/sagemaker/jumpstart/cache.py
+++ b/src/sagemaker/jumpstart/cache.py
@@ -58,7 +58,6 @@
     DescribeHubContentsResponse,
     HubType,
     HubContentType,
-    HubDataType,
 )
 from sagemaker.jumpstart.curated_hub import utils as hub_utils
 from sagemaker.jumpstart.enums import JumpStartModelType
diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py
index 4cdd540111..6406932924 100644
--- a/src/sagemaker/jumpstart/estimator.py
+++ b/src/sagemaker/jumpstart/estimator.py
@@ -534,7 +534,6 @@ def _validate_model_id_and_get_type_hook():
             model_version=model_version,
             hub_arn=hub_arn,
             model_type=self.model_type,
-            hub_arn=hub_arn,
             tolerate_vulnerable_model=tolerate_vulnerable_model,
             tolerate_deprecated_model=tolerate_deprecated_model,
             role=role,
diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py
index fa04b46a7c..fb598256fa 100644
--- a/src/sagemaker/jumpstart/factory/estimator.py
+++ b/src/sagemaker/jumpstart/factory/estimator.py
@@ -81,7 +81,6 @@ def get_init_kwargs(
     model_version: Optional[str] = None,
     hub_arn: Optional[str] = None,
     model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
-    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: Optional[bool] = None,
     tolerate_deprecated_model: Optional[bool] = None,
     region: Optional[str] = None,
@@ -141,7 +140,6 @@ def get_init_kwargs(
         model_version=model_version,
         hub_arn=hub_arn,
         model_type=model_type,
-        hub_arn=hub_arn,
         role=role,
         region=region,
         instance_count=instance_count,
diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py
index 2fa538f33b..6f7a83cef1 100644
--- a/src/sagemaker/jumpstart/factory/model.py
+++ b/src/sagemaker/jumpstart/factory/model.py
@@ -550,7 +550,6 @@ def get_deploy_kwargs(
     model_version: Optional[str] = None,
     hub_arn: Optional[str] = None,
     model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
-    hub_arn: Optional[str] = None,
     region: Optional[str] = None,
     initial_instance_count: Optional[int] = None,
     instance_type: Optional[str] = None,
@@ -585,7 +584,6 @@ def get_deploy_kwargs(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
-       hub_arn=hub_arn,
        region=region,
        initial_instance_count=initial_instance_count,
        instance_type=instance_type,
@@ -726,7 +724,6 @@ def get_init_kwargs(
     model_version: Optional[str] = None,
     hub_arn: Optional[str] = None,
     model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
-    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: Optional[bool] = None,
     tolerate_deprecated_model: Optional[bool] = None,
     instance_type: Optional[str] = None,
@@ -760,7 +757,6 @@ def get_init_kwargs(
         model_version=model_version,
         hub_arn=hub_arn,
         model_type=model_type,
-        hub_arn=hub_arn,
         instance_type=instance_type,
         region=region,
         image_uri=image_uri,
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py
index 485dc0b6b9..01622a5462 100644
--- a/src/sagemaker/jumpstart/types.py
+++ b/src/sagemaker/jumpstart/types.py
@@ -15,6 +15,7 @@
 from copy import deepcopy
 from enum import Enum
 from typing import Any, Dict, List, Optional, Set, Union
+from sagemaker.session import Session
 from sagemaker.utils import get_instance_type_family, format_tags, Tags
 from sagemaker.enums import EndpointType
 from sagemaker.model_metrics import ModelMetrics
@@ -1290,7 +1291,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
         "model_version",
         "hub_arn",
         "model_type",
-        "hub_arn",
         "instance_type",
         "tolerate_vulnerable_model",
         "tolerate_deprecated_model",
@@ -1323,7 +1323,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
         "model_version",
         "hub_arn",
         "model_type",
-        "hub_arn",
         "tolerate_vulnerable_model",
         "tolerate_deprecated_model",
         "region",
@@ -1337,7 +1336,6 @@ def __init__(
         model_version: Optional[str] = None,
         hub_arn: Optional[str] = None,
         model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn: Optional[str] = None,
         region: Optional[str] = None,
         instance_type: Optional[str] = None,
         image_uri: Optional[Union[str, Any]] = None,
@@ -1369,7 +1367,6 @@ def __init__(
         self.model_version = model_version
         self.hub_arn = hub_arn
         self.model_type = model_type
-        self.hub_arn = hub_arn
         self.instance_type = instance_type
         self.region = region
         self.image_uri = image_uri
@@ -1404,7 +1401,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
         "model_version",
         "hub_arn",
         "model_type",
-        "hub_arn",
         "initial_instance_count",
         "instance_type",
         "region",
@@ -1436,7 +1432,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
     SERIALIZATION_EXCLUSION_SET = {
         "model_id",
         "model_version",
-        "hub_arn",
         "model_type",
         "hub_arn",
         "region",
@@ -1485,7 +1480,6 @@ def __init__(
         self.model_version = model_version
         self.hub_arn = hub_arn
         self.model_type = model_type
-        self.hub_arn = hub_arn
         self.initial_instance_count = initial_instance_count
         self.instance_type = instance_type
         self.region = region
@@ -1522,7 +1516,6 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
         "model_version",
         "hub_arn",
         "model_type",
-        "hub_arn",
         "instance_type",
         "instance_count",
         "region",
@@ -1584,7 +1577,6 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
         "model_version",
         "hub_arn",
         "model_type",
-        "hub_arn",
     }

     def __init__(
@@ -1714,7 +1706,6 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
         "model_version",
         "hub_arn",
         "model_type",
-        "hub_arn",
         "region",
         "inputs",
         "wait",
@@ -1731,7 +1722,6 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
         "model_version",
         "hub_arn",
         "model_type",
-        "hub_arn",
         "region",
         "tolerate_deprecated_model",
         "tolerate_vulnerable_model",
@@ -1760,7 +1750,6 @@ def __init__(
         self.model_version = model_version
         self.hub_arn = hub_arn
         self.model_type = model_type
-        self.hub_arn = hub_arn
         self.region = region
         self.inputs = inputs
         self.wait = wait
diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py
index 93d7098870..0f69cb572a 100644
--- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py
+++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py
@@ -453,7 +453,6 @@ def add_options_to_hyperparameter(*largs, **kwargs):
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_get_model_specs.reset_mock()
@@ -517,7 +516,6 @@ def test_jumpstart_validate_all_hyperparameters(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_get_model_specs.reset_mock()
diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py
index 45af6faeed..bd4383499d 100644
--- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py
+++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py
@@ -56,7 +56,6 @@ def test_jumpstart_common_image_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -79,7 +78,6 @@ def test_jumpstart_common_image_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -102,7 +100,6 @@ def test_jumpstart_common_image_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -125,7 +122,6 @@ def test_jumpstart_common_image_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
index 2e7ec017bf..b4b2eaabb2 100644
--- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
+++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
@@ -16,7 +16,6 @@
 from sagemaker.jumpstart.types import HubArnExtractedInfo
 from sagemaker.jumpstart.curated_hub import utils
 from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
-from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo


 def test_get_info_from_hub_resource_arn():
diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
index ed7c870a0e..a5d1ee3ac2 100644
--- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
+++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
@@ -751,5 +751,4 @@ def test_get_model_url(
         s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )
diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py
index 06587a2074..2bb327c26f 100644
--- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py
+++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py
@@ -54,7 +54,6 @@ def test_jumpstart_common_model_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -74,7 +73,6 @@ def test_jumpstart_common_model_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -95,7 +93,6 @@ def test_jumpstart_common_model_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -116,7 +113,6 @@ def test_jumpstart_common_model_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py
index b2e055dd3c..2a4d913a75 100644
--- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py
+++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py
@@ -57,7 +57,6 @@ def test_jumpstart_resource_requirements(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_get_model_specs.reset_mock()
diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py
index 14ad48082e..e1d3ef6ae1 100644
--- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py
+++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py
@@ -52,7 +52,6 @@ def test_jumpstart_common_script_uri(
         model_id="pytorch-ic-mobilenet-v2",
         version="*",
         s3_client=mock_client,
-        hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
         hub_arn=None,
     )
@@ -74,7 +73,6 @@ def test_jumpstart_common_script_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -95,7 +93,6 @@ def test_jumpstart_common_script_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()
@@ -116,7 +113,6 @@ def test_jumpstart_common_script_uri(
         s3_client=mock_client,
         hub_arn=None,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
-        hub_arn=None,
     )

     patched_verify_model_region_and_return_specs.assert_called_once()

From d2dd9be173658c250b9a65da2386720703dc6d8a Mon Sep 17 00:00:00 2001
From: Benjamin Crabtree
Date: Fri, 15 Mar 2024 15:34:09 +0000
Subject: [PATCH 41/42] trying to fix codecov

---
 .github/workflows/codebuild-ci.yml | 90 +++++++++++++++---------------
 1 file changed, 45 insertions(+), 45 deletions(-)

diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml
index e72680be2a..a6e7e8e897 100644
--- a/.github/workflows/codebuild-ci.yml
+++ b/.github/workflows/codebuild-ci.yml
@@ -1,48 +1,48 @@
-name: PR Checks
-on:
-  pull_request_target:
+# name: PR Checks
+# on:
+#   pull_request_target:

-concurrency:
-  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }}
-  cancel-in-progress: true
+# concurrency:
+#   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }}
+#   cancel-in-progress: true

-permissions:
-  id-token: write # This is required for requesting the JWT
+# permissions:
+#   id-token: write # This is required for requesting the JWT

-jobs:
-  codestyle-doc-tests:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Configure AWS Credentials
-        uses: aws-actions/configure-aws-credentials@v4
-        with:
-          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
-          aws-region: us-west-2
-          role-duration-seconds: 10800
-      - name: Run Codestyle & Doc Tests
-        uses: aws-actions/aws-codebuild-run-build@v1
-        with:
-          project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
-          source-version-override: 'pr/${{ github.event.pull_request.number }}'
-  unit-tests:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: ["py38", "py39", "py310"]
-    steps:
-      - name: Configure AWS Credentials
-        uses: aws-actions/configure-aws-credentials@v4
-        with:
-          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
-          aws-region: us-west-2
-          role-duration-seconds: 10800
-      - name: Run Unit Tests
-        uses: aws-actions/aws-codebuild-run-build@v1
-        with:
-          project-name: sagemaker-python-sdk-ci-unit-tests
-          source-version-override: 'pr/${{ github.event.pull_request.number }}'
-          env-vars-for-codebuild: |
-            PY_VERSION
-        env:
-          PY_VERSION: ${{ matrix.python-version }}
+# jobs:
+#   codestyle-doc-tests:
+#     runs-on: ubuntu-latest
+#     steps:
+#       - name: Configure AWS Credentials
+#         uses: aws-actions/configure-aws-credentials@v4
+#         with:
+#           role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
+#           aws-region: us-west-2
+#           role-duration-seconds: 10800
+#       - name: Run Codestyle & Doc Tests
+#         uses: aws-actions/aws-codebuild-run-build@v1
+#         with:
+#           project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
+#           source-version-override: 'pr/${{ github.event.pull_request.number }}'
+#   unit-tests:
+#     runs-on: ubuntu-latest
+#     strategy:
+#       fail-fast: false
+#       matrix:
+#         python-version: ["py38", "py39", "py310"]
+#     steps:
+#       - name: Configure AWS Credentials
+#         uses: aws-actions/configure-aws-credentials@v4
+#         with:
+#           role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
+#           aws-region: us-west-2
+#           role-duration-seconds: 10800
+#       - name: Run Unit Tests
+#         uses: aws-actions/aws-codebuild-run-build@v1
+#         with:
+#           project-name: sagemaker-python-sdk-ci-unit-tests
+#           source-version-override: 'pr/${{ github.event.pull_request.number }}'
+#           env-vars-for-codebuild: |
+#             PY_VERSION
+#         env:
+#           PY_VERSION: ${{ matrix.python-version }}

From fa6a3ba326a57793dfebe221110bed00dfe2b65e Mon Sep 17 00:00:00 2001
From: Benjamin Crabtree
Date: Fri, 15 Mar 2024 18:48:24 +0000
Subject: [PATCH 42/42] uncomment codebuild-ci

---
 .github/workflows/codebuild-ci.yml | 90 +++++++++++++++---------------
 1 file changed, 45 insertions(+), 45 deletions(-)

diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml
index a6e7e8e897..cf46938efb 100644
--- a/.github/workflows/codebuild-ci.yml
+++ b/.github/workflows/codebuild-ci.yml
@@ -1,48 +1,48 @@
-# name: PR Checks
-# on:
-#   pull_request_target:
+name: PR Checks
+on:
+  pull_request_target:

-# concurrency:
-#   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }}
-#   cancel-in-progress: true
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }}
+  cancel-in-progress: true

-# permissions:
-#   id-token: write # This is required for requesting the JWT
+permissions:
+  id-token: write # This is required for requesting the JWT

-# jobs:
-#   codestyle-doc-tests:
-#     runs-on: ubuntu-latest
-#     steps:
-#       - name: Configure AWS Credentials
-#         uses: aws-actions/configure-aws-credentials@v4
-#         with:
-#           role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
-#           aws-region: us-west-2
-#           role-duration-seconds: 10800
-#       - name: Run Codestyle & Doc Tests
-#         uses: aws-actions/aws-codebuild-run-build@v1
-#         with:
-#           project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
-#           source-version-override: 'pr/${{ github.event.pull_request.number }}'
-#   unit-tests:
-#     runs-on: ubuntu-latest
-#     strategy:
-#       fail-fast: false
-#       matrix:
-#         python-version: ["py38", "py39", "py310"]
-#     steps:
-#       - name: Configure AWS Credentials
-#         uses: aws-actions/configure-aws-credentials@v4
-#         with:
-#           role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
-#           aws-region: us-west-2
-#           role-duration-seconds: 10800
-#       - name: Run Unit Tests
-#         uses: aws-actions/aws-codebuild-run-build@v1
-#         with:
-#           project-name: sagemaker-python-sdk-ci-unit-tests
-#           source-version-override: 'pr/${{ github.event.pull_request.number }}'
-#           env-vars-for-codebuild: |
-#             PY_VERSION
-#         env:
-#           PY_VERSION: ${{ matrix.python-version }}
+jobs:
+  codestyle-doc-tests:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Configure AWS Credentials
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
+          aws-region: us-west-2
+          role-duration-seconds: 10800
+      - name: Run Codestyle & Doc Tests
+        uses: aws-actions/aws-codebuild-run-build@v1
+        with:
+          project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
+          source-version-override: "pr/${{ github.event.pull_request.number }}"
+  unit-tests:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["py38", "py39", "py310"]
+    steps:
+      - name: Configure AWS Credentials
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
+          aws-region: us-west-2
+          role-duration-seconds: 10800
+      - name: Run Unit Tests
+        uses: aws-actions/aws-codebuild-run-build@v1
+        with:
+          project-name: sagemaker-python-sdk-ci-unit-tests
+          source-version-override: "pr/${{ github.event.pull_request.number }}"
+          env-vars-for-codebuild: |
+            PY_VERSION
+        env:
+          PY_VERSION: ${{ matrix.python-version }}