Skip to content

Commit b852e7f

Browse files
jgoyani1knikure
authored andcommitted
fix: update get_execution_role_arn from metadata file if present (#4388)
1 parent 8462f1a commit b852e7f

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/sagemaker/session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4881,14 +4881,19 @@ def get_caller_identity_arn(self):
48814881
domain_id = metadata.get("DomainId")
48824882
user_profile_name = metadata.get("UserProfileName")
48834883
space_name = metadata.get("SpaceName")
4884+
execution_role_arn = metadata.get("ExecutionRoleArn")
48844885
try:
48854886
if domain_id is None:
48864887
instance_desc = self.sagemaker_client.describe_notebook_instance(
48874888
NotebookInstanceName=instance_name
48884889
)
48894890
return instance_desc["RoleArn"]
48904891

4891-
# In Space app, find execution role from DefaultSpaceSettings on domain level
4892+
# find execution role from the metadata file if present
4893+
if execution_role_arn is not None:
4894+
return execution_role_arn
4895+
4896+
# In Shared Space app, find execution role from DefaultSpaceSettings on domain level
48924897
if space_name is not None:
48934898
domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id)
48944899
return domain_desc["DefaultSpaceSettings"]["ExecutionRole"]

tests/unit/test_session.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,25 @@ def test_fallback_to_domain_if_role_unavailable_in_user_settings(boto_session):
626626
sess.sagemaker_client.describe_domain.assert_called_once_with(DomainId="d-kbnw5yk6tg8j")
627627

628628

629+
@patch(
630+
"six.moves.builtins.open",
631+
mock_open(
632+
read_data='{"ResourceName": "SageMakerInstance", '
633+
'"DomainId": "d-kbnw5yk6tg8j", '
634+
'"ExecutionRoleArn": "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388", '
635+
'"SpaceName": "space_name"}'
636+
),
637+
)
638+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
639+
def test_get_caller_identity_arn_from_metadata_file_for_space(boto_session):
640+
sess = Session(boto_session)
641+
expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
642+
643+
actual = sess.get_caller_identity_arn()
644+
645+
assert actual == expected_role
646+
647+
629648
@patch(
630649
"six.moves.builtins.open",
631650
mock_open(

0 commit comments

Comments
 (0)