Skip to content

Commit f8fc77e

Browse files
committed
Add HubNotebookDocument
1 parent 86467e9 commit f8fc77e

File tree

5 files changed

+112
-20
lines changed

5 files changed

+112
-20
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,18 @@ def _retrieval_function(
441441
return JumpStartCachedContentValue(
442442
formatted_content=model_specs
443443
)
444+
445+
if data_type == HubContentType.NOTEBOOK:
446+
hub_name, _, notebook_name, notebook_version = hub_utils.get_info_from_hub_resource_arn(id_info)
447+
response: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
448+
hub_name=hub_name,
449+
hub_content_name=notebook_name,
450+
hub_content_version=notebook_version,
451+
hub_content_type=data_type,
452+
)
453+
hub_notebook_description = DescribeHubContentResponse(response)
454+
return JumpStartCachedContentValue(formatted_content=hub_notebook_description)
455+
444456
if data_type == HubContentType.MODEL:
445457
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(id_info)
446458
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
@@ -468,7 +480,7 @@ def _retrieval_function(
468480
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
469481
hub_description = DescribeHubResponse(response)
470482
return JumpStartCachedContentValue(
471-
formatted_content=DescribeHubResponse(hub_description),
483+
formatted_content=hub_description,
472484
)
473485

474486
raise ValueError(

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
DescribeHubContentResponse,
4343
HubContentType,
4444
JumpStartModelSpecs,
45-
HubContentDocument,
45+
HubModelDocument,
4646
)
4747
from sagemaker.jumpstart.curated_hub.utils import (
4848
create_hub_bucket_if_it_does_not_exist,
@@ -389,7 +389,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
389389
f"{TASK_TAG_PREFIX}:TODO: pull from specs",
390390
]
391391

392-
hub_content_document = HubContentDocument(
392+
hub_content_document = HubModelDocument(
393393
json_obj_or_model_specs=model_specs, region=self.region
394394
)
395395

src/sagemaker/jumpstart/types.py

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -928,7 +928,7 @@ class HubContentDependency(JumpStartDataHolderType):
928928
Content can be scripts, model artifacts, datasets, or notebooks.
929929
"""
930930

931-
__slots__ = ["dependency_copy_path", "dependency_origin_path"]
931+
__slots__ = ["dependency_copy_path", "dependency_origin_path", "dependency_type"]
932932

933933
def __init__(self, json_obj: Dict[str, Any]) -> None:
934934
"""Instantiates HubContentDependency object
@@ -947,6 +947,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
947947

948948
self.dependency_copy_path: Optional[str] = json_obj.get("DependencyCopyPath", "")
949949
self.dependency_origin_path: Optional[str] = json_obj.get("DependencyOriginPath", "")
950+
self.dependency_type: Optional[str] = json_obj.get("DependencyType", "")
950951

951952

952953
class DescribeHubContentResponse(JumpStartDataHolderType):
@@ -1000,14 +1001,24 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
10001001
self.hub_content_display_name: str = json_obj["HubContentDisplayName"]
10011002
hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn)
10021003
self._region = hub_region
1003-
self.hub_content_document: str = HubContentDocument(
1004-
json_obj_or_model_specs=json_obj["HubContentDocument"], region=self._region
1005-
)
1004+
self.hub_content_type: HubContentType = json_obj["HubContentType"]
1005+
if self.hub_content_type == HubContentType.MODEL:
1006+
self.hub_content_document: HubContentDocument = HubModelDocument(
1007+
json_obj_or_model_specs=json_obj["HubContentDocument"], region=self._region
1008+
)
1009+
elif self.hub_content_type == HubContentType.NOTEBOOK:
1010+
self.hub_content_document: HubContentDocument = HubNotebookDocument(
1011+
json_obj=json_obj["HubContentDocument"], region=self._region
1012+
)
1013+
else:
1014+
raise ValueError(
1015+
f"[{self.hub_content_type}] is not a valid HubContentType. Should be one of: {[item.name for item in HubContentType]}."
1016+
)
1017+
10061018
self.hub_content_markdown: str = json_obj["HubContentMarkdown"]
10071019
self.hub_content_name: str = json_obj["HubContentName"]
10081020
self.hub_content_search_keywords: List[str] = json_obj["HubContentSearchKeywords"]
10091021
self.hub_content_status: str = json_obj["HubContentStatus"]
1010-
self.hub_content_type: HubContentType = json_obj["HubContentType"]
10111022
self.hub_content_version: str = json_obj["HubContentVersion"]
10121023
self.hub_name: str = json_obj["HubName"]
10131024

@@ -1425,7 +1436,7 @@ def from_describe_hub_content_response(self, response: DescribeHubContentRespons
14251436
self.version: str = response.hub_content_version
14261437
# CuratedHub is regionalized
14271438
hub_region: Optional[str] = response.get_hub_region()
1428-
hub_content_document: HubContentDocument = HubContentDocument(
1439+
hub_content_document: HubModelDocument = HubModelDocument(
14291440
response.hub_content_document, region=hub_region
14301441
)
14311442
self.url: str = hub_content_document.url
@@ -1590,10 +1601,11 @@ def get_framework(self) -> str:
15901601
return self.model_id.split("-")[0]
15911602

15921603

1593-
class HubContentDocument(JumpStartDataHolderType):
1604+
class HubModelDocument(JumpStartDataHolderType):
1605+
"""Data class for model type HubContentDocument from session.describe_hub_content()."""
1606+
15941607
SCHEMA_VERSION = "2.0.0"
15951608

1596-
"""Data class for HubContentDocument from session.describe_hub_content()."""
15971609
__slots__ = [
15981610
"url",
15991611
"min_sdk_version",
@@ -1659,14 +1671,21 @@ class HubContentDocument(JumpStartDataHolderType):
16591671
"dependencies",
16601672
"_region",
16611673
]
1674+
16621675
_non_serializable_slots = ["_region"]
16631676

16641677
def __init__(
16651678
self,
16661679
json_obj_or_model_specs: Union[Dict[str, Any], JumpStartModelSpecs],
16671680
region: str,
16681681
) -> None:
1669-
self._region = region # Handle region
1682+
"""Instantiates HubModelDocument object.
1683+
1684+
Args:
1685+
json_obj (Dict[str, Any]): Dictionary representation of hub content document.
1686+
"""
1687+
1688+
self._region = region
16701689
if isinstance(json_obj_or_model_specs, Dict):
16711690
self.from_json(json_obj_or_model_specs)
16721691
elif isinstance(json_obj_or_model_specs, JumpStartModelSpecs):
@@ -1689,7 +1708,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
16891708
]
16901709
self.training_supported: bool = bool(json_obj["TrainingSupported"])
16911710
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
1692-
self.dependencies = json_obj.get("Dependencies", [])
1711+
self.dependencies: List[HubContentDependency] = [
1712+
HubContentDependency(dep) for dep in json_obj["Dependencies"]
1713+
]
16931714

16941715
self.dynamic_container_deployment_supported: Optional[bool] = (
16951716
bool(json_obj.get("DynamicContainerDeploymentSupported"))
@@ -1979,6 +2000,45 @@ def to_model_specs(
19792000
pass
19802001

19812002

2003+
class HubNotebookDocument(JumpStartDataHolderType):
2004+
"""Data class for notebook type HubContentDocument from session.describe_hub_content()."""
2005+
2006+
SCHEMA_VERSION = "1.0.0"
2007+
2008+
__slots__ = ["notebook_location", "dependencies", "_region"]
2009+
2010+
_non_serializable_slots = ["_region"]
2011+
2012+
def __init__(self, json_obj: Dict[str, Any], region: str) -> None:
2013+
"""Instantiates HubNotebookDocument object.
2014+
2015+
Args:
2016+
json_obj (Dict[str, Any]): Dictionary representation of hub content document.
2017+
"""
2018+
self._region = region
2019+
self.from_json(json_obj)
2020+
2021+
def from_json(self, json_obj: Dict[str, Any]) -> None:
2022+
"""Sets fields in object based on json.
2023+
2024+
Args:
2025+
json_obj (Dict[str, Any]): Dictionary representation of hub content description.
2026+
"""
2027+
self.notebook_location = json_obj["NotebookLocation"]
2028+
self.dependencies: List[HubContentDependency] = [
2029+
HubContentDependency(dep) for dep in json_obj["Dependencies"]
2030+
]
2031+
2032+
def get_schema_version(self) -> str:
2033+
"""Returns schema version."""
2034+
return self.SCHEMA_VERSION
2035+
2036+
def get_region(self) -> str:
2037+
return self._region
2038+
2039+
HubContentDocument = Union[HubModelDocument, HubNotebookDocument]
2040+
2041+
19822042
class HubContentSummary(JumpStartDataHolderType):
19832043
"""Data class for the HubContentSummary from session.list_hub_contents()."""
19842044

@@ -2021,7 +2081,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
20212081
self._region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(
20222082
self.hub_content_arn
20232083
)
2024-
self.hub_content_document: HubContentDocument = HubContentDocument(
2084+
self.hub_content_document: HubModelDocument = HubModelDocument(
20252085
json_obj_or_model_specs=json_obj["HubContentDocument"], region=self._region
20262086
)
20272087
self.hub_content_name: str = json_obj["HubContentName"]

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8602,3 +8602,14 @@
86028602
"Dependencies": [],
86038603
},
86048604
}
8605+
8606+
BASE_HUB_NOTEBOOK_DOCUMENT = {
8607+
"NotebookLocation": "s3://sagemaker-test-objects-do-not-delete/tensorflow-notebooks/tensorflow-ic-bit-s-r101x3-ilsvrc2012-classification-1-inference.ipynb",
8608+
"Dependencies": [
8609+
{
8610+
"DependencyOriginPath": "sagemaker-test-objects-do-not-delete/tensorflow-notebooks/tensorflow-ic-bit-s-r101x3-ilsvrc2012-classification-1-inference.ipynb",
8611+
"DependencyCopyPath": "sagemaker-hubs-us-west-2-802376408542/default-hub-1667253603.746/Notebook/pentest-3-notebook-1667933000.49/0.0.1/notebook.ipynb",
8612+
"DependencyType": "Notebook",
8613+
}
8614+
],
8615+
}

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@
1818
JumpStartInstanceTypeVariants,
1919
JumpStartModelSpecs,
2020
JumpStartModelHeader,
21-
HubContentDocument,
21+
HubModelDocument,
22+
)
23+
from tests.unit.sagemaker.jumpstart.constants import (
24+
BASE_SPEC,
25+
HUB_MODEL_DOCUMENT_DICTS,
26+
BASE_HUB_NOTEBOOK_DOCUMENT,
2227
)
23-
from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC, HUB_MODEL_DOCUMENT_DICTS
2428

2529
INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants(
2630
{
@@ -894,16 +898,21 @@ def test_jumpstart_training_artifact_key_instance_variants():
894898
)
895899

896900

897-
def test_hub_content_document_from_model_specs():
901+
def test_hub_model_document_from_model_specs():
898902
specs1 = JumpStartModelSpecs(BASE_SPEC)
899903
region = "us-west-2"
900-
specs2 = HubContentDocument(specs1, region)
904+
specs2 = HubModelDocument(specs1, region)
905+
# TODO: Implement
906+
pass
907+
908+
909+
def test_hub_model_document_from_json_obj():
901910
# TODO: Implement
902911
pass
903912

904913

905-
def test_hub_content_document_from_json_obj():
906-
# TODO: implement
914+
def test_hub_notebook_document():
915+
# TODO: Implement
907916
pass
908917

909918

0 commit comments

Comments
 (0)