diff --git a/awscli/customizations/configure/sso.py b/awscli/customizations/configure/sso.py index be5f60639bed..6abb54d7d017 100644 --- a/awscli/customizations/configure/sso.py +++ b/awscli/customizations/configure/sso.py @@ -439,10 +439,12 @@ def _handle_single_account(self, accounts): def _handle_multiple_accounts(self, accounts): available_accounts_msg = ( 'There are {} AWS accounts available to you.\n' + 'Use arrow keys to navigate, type to filter, and press Enter to select.\n' ) uni_print(available_accounts_msg.format(len(accounts))) selected_account = self._selector( - accounts, display_format=display_account + accounts, display_format=display_account, enable_filter=True, + no_results_message='No matching accounts found' ) sso_account_id = selected_account['accountId'] return sso_account_id @@ -456,6 +458,8 @@ def _prompt_for_account(self, sso, sso_token): accounts = self._get_all_accounts(sso, sso_token)['accountList'] if not accounts: raise RuntimeError('No AWS accounts are available to you.') + # Sort accounts by accountName for consistent ordering + accounts = sorted(accounts, key=lambda x: x.get('accountName', '')) if len(accounts) == 1: sso_account_id = self._handle_single_account(accounts) else: @@ -489,6 +493,8 @@ def _prompt_for_role(self, sso, sso_token, sso_account_id): if not roles: error_msg = 'No roles are available for the account {}' raise RuntimeError(error_msg.format(sso_account_id)) + # Sort roles by roleName for consistent ordering + roles = sorted(roles, key=lambda x: x.get('roleName', '')) if len(roles) == 1: sso_role_name = self._handle_single_role(roles) else: diff --git a/awscli/customizations/wizard/ui/selectmenu.py b/awscli/customizations/wizard/ui/selectmenu.py index 67cb7a098edf..700b5940b787 100644 --- a/awscli/customizations/wizard/ui/selectmenu.py +++ b/awscli/customizations/wizard/ui/selectmenu.py @@ -23,7 +23,10 @@ from prompt_toolkit.utils import get_cwidth -def select_menu(items, display_format=None, max_height=10): +def select_menu( + items, display_format=None, max_height=10, enable_filter=False, + no_results_message=None +): """Presents a list of options and allows the user to select one. This presents a static list of options and prompts the user to select one. @@ -42,6 +45,12 @@ def select_menu(items, display_format=None, max_height=10): :type max_height: int :param max_height: The max number of items to show in the list at a time. + :type enable_filter: bool + :param enable_filter: Enable keyboard filtering of items. + + :type no_results_message: str + :param no_results_message: Message to show when filtering returns no results. + :returns: The selected element from the items list. """ app_bindings = KeyBindings() @@ -51,8 +60,20 @@ def exit_app(event): event.app.exit(exception=KeyboardInterrupt, style='class:aborting') min_height = min(max_height, len(items)) + if enable_filter: + # Add 1 to height for filter line + min_height = min(max_height + 1, len(items) + 1) + menu_control = FilterableSelectionMenuControl( + items, display_format=display_format, + no_results_message=no_results_message + ) + else: + menu_control = SelectionMenuControl( + items, display_format=display_format + ) + menu_window = Window( - SelectionMenuControl(items, display_format=display_format), + menu_control, always_hide_cursor=False, height=Dimension(min=min_height, max=min_height), scroll_offsets=ScrollOffsets(), @@ -122,6 +143,8 @@ def is_focusable(self): def preferred_width(self, max_width): items = self._get_items() + if not items: + return self.MIN_WIDTH if self._display_format: items = (self._display_format(i) for i in items) max_item_width = max(get_cwidth(i) for i in items) @@ -188,6 +211,157 @@ def app_result(event): return kb +class FilterableSelectionMenuControl(SelectionMenuControl): + """Menu that supports keyboard filtering of items""" + + def __init__(self, items, display_format=None, cursor='>', no_results_message=None): + super().__init__(items, display_format=display_format, cursor=cursor) + self._filter_text = '' + self._filtered_items = items if items else [] + self._all_items = items if items else [] + self._filter_enabled = True + self._no_results_message = no_results_message or 'No matching items found' + + def _get_items(self): + if callable(self._all_items): + self._all_items = self._all_items() + return self._filtered_items + + def preferred_width(self, max_width): + # Ensure minimum width for search display + min_search_width = max(20, len("Search: " + self._filter_text) + 5) + + # Get width from filtered items + items = self._filtered_items + if not items: + # Width for no results message + no_results_width = get_cwidth(self._no_results_message) + 4 + return max(no_results_width, min_search_width) + + if self._display_format: + items_display = [self._display_format(i) for i in items] + else: + items_display = [str(i) for i in items] + + if items_display: + max_item_width = max(get_cwidth(i) for i in items_display) + max_item_width += self._format_overhead + else: + max_item_width = self.MIN_WIDTH + + max_item_width = max(max_item_width, min_search_width) + + if max_item_width < self.MIN_WIDTH: + max_item_width = self.MIN_WIDTH + return min(max_width, max_item_width) + + def _update_filtered_items(self): + """Update the filtered items based on the current filter text""" + if not self._filter_text: + self._filtered_items = self._all_items + else: + filter_lower = self._filter_text.lower() + self._filtered_items = [ + item + for item in self._all_items + if filter_lower + in ( + self._display_format(item) + if self._display_format + else str(item) + ).lower() + ] + + # Reset selection if it's out of bounds + if self._selection >= len(self._filtered_items): + self._selection = 0 + + def preferred_height(self, width, max_height, wrap_lines, get_line_prefix): + # Add 1 extra line for the filter display + return min(max_height, len(self._get_items()) + 1) + + def create_content(self, width, height): + def get_line(i): + # First line shows the filter + if i == 0: + filter_display = ( + f"Search: {self._filter_text}_" + if self._filter_enabled + else f"Search: {self._filter_text}" + ) + return [('class:filter', filter_display)] + + # Show "No results" message if filtered items is empty + if not self._filtered_items: + if i == 1: + return [ + ('class:no-results', f' {self._no_results_message}') + ] + return [('', '')] + + # Adjust for the filter line + item_index = i - 1 + if item_index >= len(self._filtered_items): + return [('', '')] + + item = self._filtered_items[item_index] + is_selected = item_index == self._selection + return self._menu_item_fragment(item, is_selected, width) + + # Ensure at least 2 lines (search + no results or items) + line_count = max(2, len(self._filtered_items) + 1) + cursor_y = self._selection + 1 if self._filtered_items else 0 + + return UIContent( + get_line=get_line, + cursor_position=Point(x=0, y=cursor_y), + line_count=line_count, + ) + + def get_key_bindings(self): + kb = KeyBindings() + + @kb.add('up') + def move_up(event): + if len(self._filtered_items) > 0: + self._move_cursor(-1) + + @kb.add('down') + def move_down(event): + if len(self._filtered_items) > 0: + self._move_cursor(1) + + @kb.add('enter') + def app_result(event): + if len(self._filtered_items) > 0: + result = self._filtered_items[self._selection] + event.app.exit(result=result) + + @kb.add('backspace') + def delete_char(event): + if self._filter_text: + self._filter_text = self._filter_text[:-1] + self._update_filtered_items() + + @kb.add('c-u') + def clear_filter(event): + self._filter_text = '' + self._update_filtered_items() + + # Add support for typing any character + from string import printable + + for char in printable: + if char not in ('\n', '\r', '\t'): + + @kb.add(char) + def add_char(event, c=char): + self._filter_text += c + self._update_filtered_items() + + return kb + + class CollapsableSelectionMenuControl(SelectionMenuControl): """Menu that collapses to text with selection when loses focus""" diff --git a/tests/unit/customizations/configure/test_sso.py b/tests/unit/customizations/configure/test_sso.py index 8c60e95a0ae3..0533d2ba88d8 100644 --- a/tests/unit/customizations/configure/test_sso.py +++ b/tests/unit/customizations/configure/test_sso.py @@ -747,6 +747,8 @@ class PTKStubber: _ALLOWED_SELECT_MENU_KWARGS = { "display_format", "max_height", + "enable_filter", + "no_results_message", } def __init__(self, user_inputs=None): @@ -1484,6 +1486,66 @@ def test_configure_sso_suggests_values_from_sessions( sso_cmd = sso_cmd_factory(session=session) assert sso_cmd(args, parsed_globals) == 0 + def test_multiple_accounts_uses_filterable_menu( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_sso_list_roles, + stub_sso_list_accounts, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_legacy_inputs, + capsys, + ): + """Test that multiple accounts selection shows filter instructions""" + inputs = configure_sso_legacy_inputs + selected_account_id = inputs.account_id_select.answer["accountId"] + ptk_stubber.user_inputs = inputs + + stub_sso_list_accounts(inputs.account_id_select.expected_choices) + stub_sso_list_roles( + inputs.role_name_select.expected_choices, + expected_account_id=selected_account_id, + ) + + sso_cmd(args, parsed_globals) + + # Verify the filter instructions are shown for multiple accounts + stdout = capsys.readouterr().out + assert "Use arrow keys to navigate, type to filter" in stdout + assert "There are 2 AWS accounts available to you" in stdout + + def test_single_account_does_not_use_filterable_menu( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + capsys, + ): + """Test that single account does not show filter instructions""" + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + sso_cmd(args, parsed_globals) + + # Verify the filter instructions are NOT shown for single account + stdout = capsys.readouterr().out + assert "Use arrow keys to navigate, type to filter" not in stdout + assert "The only AWS account available to you is" in stdout + def test_single_account_single_role_device_code_fallback( self, sso_cmd, @@ -1524,12 +1586,109 @@ def test_single_account_single_role_device_code_fallback( ], ) + def test_accounts_are_sorted_by_name( + self, + sso_cmd, + ): + """Test that accounts are sorted by accountName""" + from unittest.mock import Mock + + # Create accounts in non-alphabetical order + unsorted_accounts = [ + { + "accountId": "333333333333", + "accountName": "Zulu", + "emailAddress": "zulu@example.com", + }, + { + "accountId": "111111111111", + "accountName": "Alpha", + "emailAddress": "alpha@example.com", + }, + { + "accountId": "222222222222", + "accountName": "Beta", + "emailAddress": "beta@example.com", + }, + ] + + # Mock _get_all_accounts to return unsorted accounts + sso_cmd._get_all_accounts = Mock( + return_value={"accountList": unsorted_accounts} + ) + + # Create a mock selector to capture the sorted accounts + mock_selector = Mock( + return_value={ + "accountId": "111111111111", + "accountName": "Alpha", + "emailAddress": "alpha@example.com", + } + ) + sso_cmd._selector = mock_selector + + # Call the method directly to test sorting + sso_account_id = sso_cmd._prompt_for_account( + Mock(), {"accessToken": "test-token"} # sso client mock + ) + + # Verify the selector was called with sorted accounts + mock_selector.assert_called_once() + called_accounts = mock_selector.call_args[0][0] + + # Verify accounts were sorted alphabetically by name + assert len(called_accounts) == 3 + assert called_accounts[0]["accountName"] == "Alpha" + assert called_accounts[1]["accountName"] == "Beta" + assert called_accounts[2]["accountName"] == "Zulu" + assert sso_account_id == "111111111111" + + def test_roles_are_sorted_by_name( + self, + sso_cmd, + ): + """Test that roles are sorted by roleName""" + from unittest.mock import Mock + + # Create roles in non-alphabetical order + unsorted_roles = [ + {"roleName": "PowerUser"}, + {"roleName": "Administrator"}, + {"roleName": "ReadOnly"}, + ] + + # Mock _get_all_roles to return unsorted roles + sso_cmd._get_all_roles = Mock(return_value={"roleList": unsorted_roles}) + + # Create a mock selector to capture the sorted roles + mock_selector = Mock(return_value="Administrator") + original_selector = sso_cmd._selector + sso_cmd._selector = mock_selector + + # Call the method directly to test sorting + sso_role_name = sso_cmd._prompt_for_role( + Mock(), # sso client mock + {"accessToken": "test-token"}, + "111111111111", + ) + + # Verify the selector was called with sorted role names + mock_selector.assert_called() + called_role_names = mock_selector.call_args[0][0] + + # Verify roles were sorted alphabetically + assert len(called_role_names) == 3 + assert called_role_names[0] == "Administrator" + assert called_role_names[1] == "PowerUser" + assert called_role_names[2] == "ReadOnly" + assert sso_role_name == "Administrator" + class TestPrintConclusion: def test_print_conclusion_default_profile_with_credentials( self, sso_cmd, capsys ): - sso_cmd._print_conclusion(True, 'default') + sso_cmd._print_conclusion(True, "default") captured = capsys.readouterr() assert ( "The AWS CLI is now configured to use the default profile." @@ -1564,7 +1723,7 @@ def test_print_conclusion_sso_configuration(self, sso_cmd, capsys): def test_print_conclusion_default_profile_case_insensitive( selfself, sso_cmd, capsys ): - sso_cmd._print_conclusion(True, 'DEFAULT') + sso_cmd._print_conclusion(True, "DEFAULT") captured = capsys.readouterr() assert ( "The AWS CLI is now configured to use the default profile." @@ -1750,9 +1909,7 @@ def test_strips_extra_whitespace(self): def test_can_provide_toolbar(self): toolbar = "Toolbar content" - self.prompter.get_value( - "default_value", "Prompt Text", toolbar=toolbar - ) + self.prompter.get_value("default_value", "Prompt Text", toolbar=toolbar) self.assert_expected_toolbar(toolbar) def test_can_provide_prompt_format(self): diff --git a/tests/unit/customizations/wizard/ui/test_filterable_menu.py b/tests/unit/customizations/wizard/ui/test_filterable_menu.py new file mode 100644 index 000000000000..d357f0edde47 --- /dev/null +++ b/tests/unit/customizations/wizard/ui/test_filterable_menu.py @@ -0,0 +1,341 @@ +# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import unittest +from unittest.mock import Mock, MagicMock, patch +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.layout.screen import Point + +from awscli.customizations.wizard.ui.selectmenu import ( + FilterableSelectionMenuControl, + select_menu, +) + + +class TestFilterableSelectionMenuControl(unittest.TestCase): + """Test cases for FilterableSelectionMenuControl""" + + def setUp(self): + """Set up test fixtures""" + self.test_items = [ + {"id": "1", "name": "Production", "env": "prod"}, + {"id": "2", "name": "Development", "env": "dev"}, + {"id": "3", "name": "Staging", "env": "stage"}, + {"id": "4", "name": "Testing", "env": "test"}, + ] + + def display_format(item): + return f"{item['name']} ({item['env']})" + + self.display_format = display_format + self.control = FilterableSelectionMenuControl( + self.test_items, display_format=self.display_format + ) + + def test_init(self): + """Test initialization of FilterableSelectionMenuControl""" + self.assertEqual(self.control._filter_text, "") + self.assertEqual(self.control._filtered_items, self.test_items) + self.assertEqual(self.control._all_items, self.test_items) + self.assertTrue(self.control._filter_enabled) + self.assertEqual(self.control._selection, 0) + self.assertEqual( + self.control._no_results_message, "No matching items found" + ) + + def test_filter_update_with_matching_text(self): + """Test filtering with matching text""" + self.control._filter_text = "prod" + self.control._update_filtered_items() + + self.assertEqual(len(self.control._filtered_items), 1) + self.assertEqual(self.control._filtered_items[0]["name"], "Production") + + def test_filter_update_with_non_matching_text(self): + """Test filtering with non-matching text""" + self.control._filter_text = "xyz" + self.control._update_filtered_items() + + self.assertEqual(len(self.control._filtered_items), 0) + + def test_filter_update_case_insensitive(self): + """Test that filtering is case insensitive""" + self.control._filter_text = "DEV" + self.control._update_filtered_items() + + self.assertEqual(len(self.control._filtered_items), 1) + self.assertEqual(self.control._filtered_items[0]["name"], "Development") + + def test_filter_update_partial_match(self): + """Test filtering with partial matches""" + self.control._filter_text = "ing" + self.control._update_filtered_items() + + # Should match both 'Staging' and 'Testing' + self.assertEqual(len(self.control._filtered_items), 2) + names = [item["name"] for item in self.control._filtered_items] + self.assertIn("Staging", names) + self.assertIn("Testing", names) + + def test_filter_clears_when_empty(self): + """Test that empty filter shows all items""" + self.control._filter_text = "prod" + self.control._update_filtered_items() + self.assertEqual(len(self.control._filtered_items), 1) + + self.control._filter_text = "" + self.control._update_filtered_items() + self.assertEqual(len(self.control._filtered_items), 4) + + def test_selection_reset_when_filtered(self): + """Test that selection resets when filter results change""" + self.control._selection = 2 + self.control._filter_text = "prod" + self.control._update_filtered_items() + + # Selection should reset to 0 when filtering + self.assertEqual(self.control._selection, 0) + + def test_preferred_width_with_items(self): + """Test preferred width calculation with items""" + max_width = 100 + width = self.control.preferred_width(max_width) + + # Width should be based on the longest item + self.assertGreater(width, 0) + self.assertLessEqual(width, max_width) + + def test_preferred_width_with_empty_filter_result(self): + """Test preferred width when filter returns no results""" + self.control._filter_text = "xyz" + self.control._update_filtered_items() + + max_width = 100 + width = self.control.preferred_width(max_width) + + # Should return minimum width for no results message + # Default message is "No matching items found" which is 24 chars + 4 = 28 + self.assertGreaterEqual(width, 20) + + def test_preferred_height(self): + """Test preferred height calculation""" + max_height = 10 + height = self.control.preferred_height(100, max_height, False, None) + + # Should be items count + 1 for search line + expected = min(max_height, len(self.test_items) + 1) + self.assertEqual(height, expected) + + def test_create_content_with_items(self): + """Test content creation with filtered items""" + content = self.control.create_content(50, 10) + + # First line should be the search prompt + self.assertEqual(content.line_count, len(self.test_items) + 1) + + # Cursor should be on the first item (line 1, after search line) + self.assertEqual(content.cursor_position.y, 1) + + def test_create_content_with_filter(self): + """Test content creation with active filter""" + self.control._filter_text = "test" + self.control._update_filtered_items() + + content = self.control.create_content(50, 10) + + # Should have search line + 1 filtered item + self.assertEqual(content.line_count, 2) + + def test_create_content_no_results(self): + """Test content creation when no results match filter""" + self.control._filter_text = "xyz" + self.control._update_filtered_items() + + content = self.control.create_content(50, 10) + + # Should have at least 2 lines (search + no results message) + self.assertGreaterEqual(content.line_count, 2) + + # Cursor should be on search line when no results + self.assertEqual(content.cursor_position.y, 0) + + def test_key_bindings(self): + """Test that key bindings are properly set up""" + kb = self.control.get_key_bindings() + + self.assertIsInstance(kb, KeyBindings) + + # Check that essential keys are bound + bindings = kb.bindings + key_names = [str(b.keys[0]) for b in bindings] + + self.assertIn("Keys.Up", key_names) + self.assertIn("Keys.Down", key_names) + self.assertIn("Keys.ControlM", key_names) # 'enter' is mapped to 'c-m' + self.assertIn( + "Keys.ControlH", key_names + ) # 'backspace' is mapped to 'c-h' + self.assertIn("Keys.ControlU", key_names) + + def test_move_cursor_with_filtered_items(self): + """Test cursor movement with filtered items""" + self.control._filter_text = "ing" + self.control._update_filtered_items() + + # Should have 2 items (Staging and Testing) + self.assertEqual(len(self.control._filtered_items), 2) + + # Move down + self.control._move_cursor(1) + self.assertEqual(self.control._selection, 1) + + # Move down again (should wrap to 0) + self.control._move_cursor(1) + self.assertEqual(self.control._selection, 0) + + # Move up (should wrap to last item) + self.control._move_cursor(-1) + self.assertEqual(self.control._selection, 1) + + def test_custom_no_results_message(self): + """Test custom no results message""" + custom_message = "No AWS accounts match your search" + control = FilterableSelectionMenuControl( + self.test_items, + display_format=self.display_format, + no_results_message=custom_message, + ) + + self.assertEqual(control._no_results_message, custom_message) + + # Test that custom message appears in content when no results + control._filter_text = "xyz" + control._update_filtered_items() + + content = control.create_content(50, 10) + # Get the second line (index 1) which should contain the no results message + line = content.get_line(1) + self.assertEqual(len(line), 1) + self.assertEqual(line[0][0], "class:no-results") + self.assertIn(custom_message, line[0][1]) + + +class TestSelectMenuWithFilter(unittest.TestCase): + """Test select_menu function with filtering enabled""" + + @patch("awscli.customizations.wizard.ui.selectmenu.Application") + def test_select_menu_with_filter_enabled(self, mock_app_class): + """Test that select_menu uses FilterableSelectionMenuControl when enable_filter=True""" + mock_app = MagicMock() + mock_app_class.return_value = mock_app + mock_app.run.return_value = {"id": "1", "name": "Test"} + + items = [{"id": "1", "name": "Test"}, {"id": "2", "name": "Prod"}] + + result = select_menu(items, enable_filter=True) + + # Verify Application was created + mock_app_class.assert_called_once() + + # Verify the result + self.assertEqual(result, {"id": "1", "name": "Test"}) + + @patch("awscli.customizations.wizard.ui.selectmenu.Application") + def test_select_menu_with_filter_disabled(self, mock_app_class): + """Test that select_menu uses SelectionMenuControl when enable_filter=False""" + mock_app = MagicMock() + mock_app_class.return_value = mock_app + mock_app.run.return_value = {"id": "1", "name": "Test"} + + items = [{"id": "1", "name": "Test"}, {"id": "2", "name": "Prod"}] + + result = select_menu(items, enable_filter=False) + + # Verify Application was created + mock_app_class.assert_called_once() + + # Verify the result + self.assertEqual(result, {"id": "1", "name": "Test"}) + + +class TestFilteringIntegration(unittest.TestCase): + """Integration tests for filtering in SSO configuration""" + + def test_filter_with_display_format(self): + """Test filtering works correctly with display_format function""" + accounts = [ + { + "accountId": "111111111111", + "accountName": "Production", + "emailAddress": "prod@example.com", + }, + { + "accountId": "222222222222", + "accountName": "Development", + "emailAddress": "dev@example.com", + }, + { + "accountId": "333333333333", + "accountName": "Staging", + "emailAddress": "staging@example.com", + }, + ] + + def display_account(account): + return f"{account['accountName']}, {account['emailAddress']} ({account['accountId']})" + + control = FilterableSelectionMenuControl( + accounts, display_format=display_account + ) + + # Test filtering by account name + control._filter_text = "prod" + control._update_filtered_items() + self.assertEqual(len(control._filtered_items), 1) + self.assertEqual( + control._filtered_items[0]["accountName"], "Production" + ) + + # Test filtering by email + control._filter_text = "dev@" + control._update_filtered_items() + self.assertEqual(len(control._filtered_items), 1) + self.assertEqual( + control._filtered_items[0]["accountName"], "Development" + ) + + # Test filtering by account ID + control._filter_text = "3333" + control._update_filtered_items() + self.assertEqual(len(control._filtered_items), 1) + self.assertEqual(control._filtered_items[0]["accountName"], "Staging") + + def test_empty_items_list(self): + """Test handling of empty items list""" + control = FilterableSelectionMenuControl([]) + + self.assertEqual(control._filtered_items, []) + self.assertEqual(control._all_items, []) + + # Should not crash when filtering + control._filter_text = "test" + control._update_filtered_items() + self.assertEqual(control._filtered_items, []) + + # Should return reasonable width + width = control.preferred_width(100) + self.assertGreater(width, 0) + + +if __name__ == "__main__": + unittest.main()