From cd2dd7ae27a4e9f725d1e71d930f15c188a9f95b Mon Sep 17 00:00:00 2001 From: Samrudhi Sharma Date: Thu, 14 Mar 2024 16:48:29 -0700 Subject: [PATCH] Minor update to description --- src/sagemaker/serve/builder/model_builder.py | 12 ++++++------ .../sagemaker/serve/builder/test_model_builder.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 15c0279b63a..f602dfe759f 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -73,11 +73,11 @@ MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer VERSION_DETECTION_ERROR = ( "Please install accelerate and transformers for HuggingFace (HF) model " - "size calculations pip install 'sagemaker[huggingface]'" + "size calculations e.g. pip install 'sagemaker[huggingface]'" ) -# pylint: disable=attribute-defined-outside-init +# pylint: disable=attribute-defined-outside-init, disable=E1101 @dataclass class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): """Class that builds a deployable model. @@ -730,14 +730,14 @@ def _total_inference_model_size_mib(self): to add up to an additional 20% to the given model size as found by EleutherAI. """ try: - import accelerate.commands.estimate.estimate_command_parser - import accelerate.commands.estimate.gather_data + import accelerate.commands.estimate.estimate_command_parser as estimate_parser + import accelerate.commands.estimate.gather_data as estimate_gather dtypes = self.env_vars.get("dtypes", "float32") - parser = accelerate.commands.estimate.estimate_command_parser.estimate_command_parser() + parser = estimate_parser() args = parser.parse_args([self.model, "--dtypes", dtypes]) - output = accelerate.commands.estimate.gather_data.gather_data( + output = estimate_gather( args ) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam" except ImportError as e: diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 3b60d13dfb7..a38c60b60e9 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1343,8 +1343,8 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback( self.assertEqual(model_builder._can_fit_on_single_gpu(), True) @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) - @patch("sagemaker.serve.builder.model_builder.estimate_command_parser") - @patch("sagemaker.serve.builder.model_builder.gather_data") + @patch("accelerate.commands.estimate.estimate_command_parser") + @patch("accelerate.commands.estimate.gather_data") @patch("sagemaker.image_uris.retrieve") @patch("sagemaker.djl_inference.model.urllib") @patch("sagemaker.djl_inference.model.json")