Skip to content

Commit

Permalink
update types
Browse files Browse the repository at this point in the history
  • Loading branch information
bencrabtree committed Mar 18, 2024
1 parent 1af132e commit c9f79fd
Showing 1 changed file with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType):
return getattr(self, dependency_type.value)

@property
def inference_artifact_s3_reference(self):
def inference_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model inference artifact"""
return create_s3_object_reference_from_uri(
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
)

@property
def training_artifact_s3_reference(self):
def training_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model training artifact"""
if not self.model_specs.training_supported:
return None
Expand All @@ -66,14 +66,14 @@ def training_artifact_s3_reference(self):
)

@property
def inference_script_s3_reference(self):
def inference_script_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model inference script"""
return create_s3_object_reference_from_uri(
self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
)

@property
def training_script_s3_reference(self):
def training_script_s3_reference(self) -> Optional[S3ObjectLocation]:
"""Retrieves s3 reference for model training script"""
if not self.model_specs.training_supported:
return None
Expand All @@ -82,21 +82,21 @@ def training_script_s3_reference(self):
)

@property
def default_training_dataset_s3_reference(self):
def default_training_dataset_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for s3 directory containing model training datasets"""
if not self.model_specs.training_supported:
return None
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())

@property
def demo_notebook_s3_reference(self):
def demo_notebook_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model demo jupyter notebook"""
framework = self.model_specs.get_framework()
key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
return S3ObjectLocation(self._get_bucket_name(), key)

@property
def markdown_s3_reference(self):
def markdown_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model markdown"""
framework = self.model_specs.get_framework()
key = f"{framework}-metadata/{self.model_specs.model_id}.md"
Expand All @@ -106,7 +106,7 @@ def _get_bucket_name(self) -> str:
"""Retrieves s3 bucket"""
return self._bucket

def __get_training_dataset_prefix(self) -> str:
def _get_training_dataset_prefix(self) -> Optional[str]:
"""Retrieves training dataset location"""
return self.studio_specs.get("defaultDataKey")

Expand Down

0 comments on commit c9f79fd

Please sign in to comment.