Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
bencrabtree committed Mar 18, 2024
1 parent b50c557 commit 1af132e
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 23 deletions.
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/curated_hub/curated_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
self._sagemaker_session.import_hub_content(
document_schema_version=HubContentDocument_v2.SCHEMA_VERSION,
hub_content_name=model.model_id,
hub_content_version=model.version,
hub_name=self.hub_name,
hub_content_document=hub_content_document,
hub_content_type=HubContentType.MODEL,
Expand Down
2 changes: 0 additions & 2 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,8 +1434,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"model_version",
"model_type",
"hub_arn",
"model_type",
"hub_arn",
"region",
"tolerate_deprecated_model",
"tolerate_vulnerable_model",
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/sagemaker/jumpstart/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=version
)
mock_cache.get_specs.assert_called_once_with(model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS)
mock_cache.get_specs.assert_called_once_with(
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
)
mock_cache.get_hub_model.assert_not_called()

accessors.JumpStartModelsAccessor.get_model_specs(
Expand Down
20 changes: 0 additions & 20 deletions tests/unit/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,26 +254,6 @@ def patched_retrieval_function(
)
)

if datatype == HubContentType.MODEL:
_, _, _, model_name, model_version = id_info.split("/")
return JumpStartCachedContentValue(
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
)

# TODO: Implement
if datatype == HubType.HUB:
return None

if datatype == HubContentType.MODEL:
_, _, _, model_name, model_version = id_info.split("/")
return JumpStartCachedContentValue(
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
)

# TODO: Implement
if datatype == HubType.HUB:
return None

raise ValueError(f"Bad value for datatype: {datatype}")


Expand Down

0 comments on commit 1af132e

Please sign in to comment.