@@ -928,7 +928,7 @@ class HubContentDependency(JumpStartDataHolderType):
928
928
Content can be scripts, model artifacts, datasets, or notebooks.
929
929
"""
930
930
931
- __slots__ = ["dependency_copy_path" , "dependency_origin_path" ]
931
+ __slots__ = ["dependency_copy_path" , "dependency_origin_path" , "dependency_type" ]
932
932
933
933
def __init__ (self , json_obj : Dict [str , Any ]) -> None :
934
934
"""Instantiates HubContentDependency object
@@ -947,6 +947,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
947
947
948
948
self .dependency_copy_path : Optional [str ] = json_obj .get ("DependencyCopyPath" , "" )
949
949
self .dependency_origin_path : Optional [str ] = json_obj .get ("DependencyOriginPath" , "" )
950
+ self .dependency_type : Optional [str ] = json_obj .get ("DependencyType" , "" )
950
951
951
952
952
953
class DescribeHubContentResponse (JumpStartDataHolderType ):
@@ -1000,14 +1001,24 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
1000
1001
self .hub_content_display_name : str = json_obj ["HubContentDisplayName" ]
1001
1002
hub_region : Optional [str ] = HubArnExtractedInfo .extract_region_from_arn (self .hub_arn )
1002
1003
self ._region = hub_region
1003
- self .hub_content_document : str = HubContentDocument (
1004
- json_obj_or_model_specs = json_obj ["HubContentDocument" ], region = self ._region
1005
- )
1004
+ self .hub_content_type : HubContentType = json_obj ["HubContentType" ]
1005
+ if self .hub_content_type == HubContentType .MODEL :
1006
+ self .hub_content_document : HubContentDocument = HubModelDocument (
1007
+ json_obj_or_model_specs = json_obj ["HubContentDocument" ], region = self ._region
1008
+ )
1009
+ elif self .hub_content_type == HubContentType .NOTEBOOK :
1010
+ self .hub_content_document : HubContentDocument = HubNotebookDocument (
1011
+ json_obj = json_obj ["HubContentDocument" ], region = self ._region
1012
+ )
1013
+ else :
1014
+ raise ValueError (
1015
+ f"[{ self .hub_content_type } ] is not a valid HubContentType. Should be one of: { [item .name for item in HubContentType ]} ."
1016
+ )
1017
+
1006
1018
self .hub_content_markdown : str = json_obj ["HubContentMarkdown" ]
1007
1019
self .hub_content_name : str = json_obj ["HubContentName" ]
1008
1020
self .hub_content_search_keywords : List [str ] = json_obj ["HubContentSearchKeywords" ]
1009
1021
self .hub_content_status : str = json_obj ["HubContentStatus" ]
1010
- self .hub_content_type : HubContentType = json_obj ["HubContentType" ]
1011
1022
self .hub_content_version : str = json_obj ["HubContentVersion" ]
1012
1023
self .hub_name : str = json_obj ["HubName" ]
1013
1024
@@ -1425,7 +1436,7 @@ def from_describe_hub_content_response(self, response: DescribeHubContentRespons
1425
1436
self .version : str = response .hub_content_version
1426
1437
# CuratedHub is regionalized
1427
1438
hub_region : Optional [str ] = response .get_hub_region ()
1428
- hub_content_document : HubContentDocument = HubContentDocument (
1439
+ hub_content_document : HubModelDocument = HubModelDocument (
1429
1440
response .hub_content_document , region = hub_region
1430
1441
)
1431
1442
self .url : str = hub_content_document .url
@@ -1590,10 +1601,11 @@ def get_framework(self) -> str:
1590
1601
return self .model_id .split ("-" )[0 ]
1591
1602
1592
1603
1593
- class HubContentDocument (JumpStartDataHolderType ):
1604
+ class HubModelDocument (JumpStartDataHolderType ):
1605
+ """Data class for model type HubContentDocument from session.describe_hub_content()."""
1606
+
1594
1607
SCHEMA_VERSION = "2.0.0"
1595
1608
1596
- """Data class for HubContentDocument from session.describe_hub_content()."""
1597
1609
__slots__ = [
1598
1610
"url" ,
1599
1611
"min_sdk_version" ,
@@ -1659,14 +1671,21 @@ class HubContentDocument(JumpStartDataHolderType):
1659
1671
"dependencies" ,
1660
1672
"_region" ,
1661
1673
]
1674
+
1662
1675
_non_serializable_slots = ["_region" ]
1663
1676
1664
1677
def __init__ (
1665
1678
self ,
1666
1679
json_obj_or_model_specs : Union [Dict [str , Any ], JumpStartModelSpecs ],
1667
1680
region : str ,
1668
1681
) -> None :
1669
- self ._region = region # Handle region
1682
+ """Instantiates HubModelDocument object.
1683
+
1684
+ Args:
1685
+ json_obj (Dict[str, Any]): Dictionary representation of hub content document.
1686
+ """
1687
+
1688
+ self ._region = region
1670
1689
if isinstance (json_obj_or_model_specs , Dict ):
1671
1690
self .from_json (json_obj_or_model_specs )
1672
1691
elif isinstance (json_obj_or_model_specs , JumpStartModelSpecs ):
@@ -1689,7 +1708,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
1689
1708
]
1690
1709
self .training_supported : bool = bool (json_obj ["TrainingSupported" ])
1691
1710
self .incremental_training_supported : bool = bool (json_obj ["IncrementalTrainingSupported" ])
1692
- self .dependencies = json_obj .get ("Dependencies" , [])
1711
+ self .dependencies : List [HubContentDependency ] = [
1712
+ HubContentDependency (dep ) for dep in json_obj ["Dependencies" ]
1713
+ ]
1693
1714
1694
1715
self .dynamic_container_deployment_supported : Optional [bool ] = (
1695
1716
bool (json_obj .get ("DynamicContainerDeploymentSupported" ))
@@ -1979,6 +2000,45 @@ def to_model_specs(
1979
2000
pass
1980
2001
1981
2002
2003
+ class HubNotebookDocument (JumpStartDataHolderType ):
2004
+ """Data class for notebook type HubContentDocument from session.describe_hub_content()."""
2005
+
2006
+ SCHEMA_VERSION = "1.0.0"
2007
+
2008
+ __slots__ = ["notebook_location" , "dependencies" , "_region" ]
2009
+
2010
+ _non_serializable_slots = ["_region" ]
2011
+
2012
+ def __init__ (self , json_obj : Dict [str , Any ], region : str ) -> None :
2013
+ """Instantiates HubNotebookDocument object.
2014
+
2015
+ Args:
2016
+ json_obj (Dict[str, Any]): Dictionary representation of hub content document.
2017
+ """
2018
+ self ._region = region
2019
+ self .from_json (json_obj )
2020
+
2021
+ def from_json (self , json_obj : Dict [str , Any ]) -> None :
2022
+ """Sets fields in object based on json.
2023
+
2024
+ Args:
2025
+ json_obj (Dict[str, Any]): Dictionary representation of hub content description.
2026
+ """
2027
+ self .notebook_location = json_obj ["NotebookLocation" ]
2028
+ self .dependencies : List [HubContentDependency ] = [
2029
+ HubContentDependency (dep ) for dep in json_obj ["Dependencies" ]
2030
+ ]
2031
+
2032
+ def get_schema_version (self ) -> str :
2033
+ """Returns schema version."""
2034
+ return self .SCHEMA_VERSION
2035
+
2036
+ def get_region (self ) -> str :
2037
+ return self ._region
2038
+
2039
+ HubContentDocument = Union [HubModelDocument , HubNotebookDocument ]
2040
+
2041
+
1982
2042
class HubContentSummary (JumpStartDataHolderType ):
1983
2043
"""Data class for the HubContentSummary from session.list_hub_contents()."""
1984
2044
@@ -2021,7 +2081,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
2021
2081
self ._region : Optional [str ] = HubArnExtractedInfo .extract_region_from_arn (
2022
2082
self .hub_content_arn
2023
2083
)
2024
- self .hub_content_document : HubContentDocument = HubContentDocument (
2084
+ self .hub_content_document : HubModelDocument = HubModelDocument (
2025
2085
json_obj_or_model_specs = json_obj ["HubContentDocument" ], region = self ._region
2026
2086
)
2027
2087
self .hub_content_name : str = json_obj ["HubContentName" ]
0 commit comments