From d56718a197a2413d79b54e9f7af40f625e4da8a4 Mon Sep 17 00:00:00 2001 From: "narrieta@microsoft" Date: Tue, 8 Oct 2024 15:09:51 -0700 Subject: [PATCH 1/2] Create walinuxagent nftable atomically --- azurelinuxagent/ga/firewall_manager.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/azurelinuxagent/ga/firewall_manager.py b/azurelinuxagent/ga/firewall_manager.py index 7776d5faa..a2a6bdb5e 100644 --- a/azurelinuxagent/ga/firewall_manager.py +++ b/azurelinuxagent/ga/firewall_manager.py @@ -392,13 +392,11 @@ def version(self): return self._version def setup(self): - shellutil.run_command(["nft", "add", "table", "ip", "walinuxagent"]) - shellutil.run_command(["nft", "add", "chain", "ip", "walinuxagent", "output", "{", "type", "filter", "hook", "output", "priority", "0", ";", "policy", "accept", ";", "}"]) - shellutil.run_command([ - "nft", "add", "rule", "ip", "walinuxagent", "output", "ip", "daddr", self._wire_server_address, - "tcp", "dport", "!=", "53", - "skuid", "!=", str(os.getuid()), - "ct", "state", "invalid,new", "counter", "drop"]) + shellutil.run_command(["nft", "-f", "-"], input=""" + add table ip walinuxagent + add chain ip walinuxagent output {{ type filter hook output priority 0 ; policy accept ; }} + add rule ip walinuxagent output ip daddr {0} tcp dport != 53 skuid != {1} ct state invalid,new counter drop + """.format(self._wire_server_address, os.getuid())) def remove(self): shellutil.run_command(["nft", "delete", "table", "walinuxagent"]) From a54be64ebe4c72dd967e7ba7773fe5fa6983abf9 Mon Sep 17 00:00:00 2001 From: "narrieta@microsoft" Date: Tue, 8 Oct 2024 21:57:26 -0700 Subject: [PATCH 2/2] Update unit tests --- tests/ga/test_firewall_manager.py | 17 +++++---- tests/lib/mock_firewall_command.py | 56 ++++++++---------------------- 2 files changed, 22 insertions(+), 51 deletions(-) diff --git a/tests/ga/test_firewall_manager.py b/tests/ga/test_firewall_manager.py index 9559067ed..d7868dfef 100644 --- a/tests/ga/test_firewall_manager.py +++ b/tests/ga/test_firewall_manager.py @@ -217,21 +217,20 @@ def test_setup_should_set_the_walinuxagent_table(self): firewall = NfTables('168.63.129.16') firewall.setup() - self.assertEqual( - [ - mock_nft.get_add_command("table"), - mock_nft.get_add_command("chain"), - mock_nft.get_add_command("rule"), - ], - mock_nft.call_list, - "Expected exactly 3 calls, to the add the walinuxagent table, output chain, and wireserver rule") + self.assertEqual(len(mock_nft.call_list), 1, "Expected exactly 1 call to execute a script to create the walinuxagent table; got {0}".format(mock_nft.call_list)) + + script = mock_nft.call_list[0] + self.assertIn("add table ip walinuxagent", script, "The setup script should to create the walinuxagent table. Script: {0}".format(script)) + self.assertIn("add chain ip walinuxagent output", script, "The setup script should to create the output chain. Script: {0}".format(script)) + self.assertIn("add rule ip walinuxagent output ", script, "The setup script should to create the rule to manage the output chain. Script: {0}".format(script)) + def test_remove_should_delete_the_walinuxagent_table(self): with MockNft() as mock_nft: firewall = NfTables('168.63.129.16') firewall.remove() - self.assertEqual([mock_nft.get_delete_command()], mock_nft.call_list, "Expected a call to delete the walinuxagent table") + self.assertEqual(['nft delete table walinuxagent'], mock_nft.call_list, "Expected a call to delete the walinuxagent table") def test_check_should_verify_all_rules(self): with MockNft() as mock_nft: diff --git a/tests/lib/mock_firewall_command.py b/tests/lib/mock_firewall_command.py index c2911a044..27a5961ea 100644 --- a/tests/lib/mock_firewall_command.py +++ b/tests/lib/mock_firewall_command.py @@ -242,15 +242,10 @@ def __init__(self): self._original_run_command = shellutil.run_command self._run_command_patcher = patch("azurelinuxagent.ga.firewall_manager.shellutil.run_command", side_effect=self._mock_run_command) # - # Return values for each nft command-line indexed by command name ("add", "delete", "list"). Each item is a (exit_code, stdout) tuple. - # These default values indicate success, and can be overridden with the set_*_return_values() methods. + # Return values for the "delete" and "list" options of the nft command. Each item is a (exit_code, stdout) tuple. + # The default values below indicate success, and can be overridden with the set_return_value() method. # self._return_values = { - "add": { - "table": (0, ''), # nft add table ip walinuxagent - "chain": (0, ''), # nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; } - "rule": (0, ''), # nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != 0 ct state invalid,new counter drop - }, "delete": { "table": (0, ''), # nft delete table walinuxagent }, @@ -300,15 +295,17 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def _mock_run_command(self, command, *args, **kwargs): if command[0] == 'nft': command_string = " ".join(command) + if command_string == "nft --version": + # return a hardcoded version string and don't add the command to the call list + return self._original_run_command(['echo', 'nftables v1.0.2 (Lester Gooch)'], *args, **kwargs) + elif command_string == 'nft -f -': + # if we are executing an nft script, add the script to the call list and return success with no stdout (empty string) + script = self._original_run_command(['cat'], *args, **kwargs) + self._call_list.append(script) + return self._original_run_command(['echo', '-n'], *args, **kwargs) + # get the exit code and stdout from the pre-defined table of return values and add the command to the call list exit_code, stdout = self.get_return_value(command_string) - script = \ -""" -cat << .. -{0} -.. -exit {1} -""".format(stdout, exit_code) - command = ['sh', '-c', script] + command = ['sh', '-c', "echo '{0}'; exit {1}".format(stdout, exit_code)] self._call_list.append(command_string) return self._original_run_command(command, *args, **kwargs) @@ -323,29 +320,18 @@ def set_return_value(self, command, target, return_value): """ Changes the return values for the mocked command """ + if command not in self._return_values or target not in self._return_values[command]: + raise Exception("Unexpected command: {0} {1}".format(command, target)) self._return_values[command][target] = return_value def get_return_value(self, command): """ Possible commands are: - nft add table ip walinuxagent - nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; } - nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != 0 ct state invalid,new counter drop nft delete table walinuxagent nft --json list tables nft --json list table walinuxagent """ - r = r"nft add (?Ptable|chain|rule)" + \ - r"(ip walinuxagent output " + \ - r"(\{ type filter hook output priority 0 ; policy accept ; })" + \ - r"|" + \ - r"(ip daddr 168.63.129.16 tcp dport != 53 skuid != \d+ ct state invalid,new counter drop)" + \ - r")?" - match = re.match(r, command) - if match is not None: - target = match.group("target") - return self._return_values["add"][target] if command == "nft delete table walinuxagent": return self._return_values["delete"]["table"] match = re.match(r"nft --json list (?Ptables|table)( walinuxagent)?", command) @@ -354,20 +340,6 @@ def get_return_value(self, command): return self._return_values["list"][target] raise Exception("Unexpected command: {0}".format(command)) - @staticmethod - def get_add_command(target): - if target == "table": - return "nft add table ip walinuxagent" - if target == "chain": - return "nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; }" - if target == "rule": - return "nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != {0} ct state invalid,new counter drop".format(os.getuid()) - raise Exception("Unexpected command target: {0}".format(target)) - - @staticmethod - def get_delete_command(): - return "nft delete table walinuxagent" - @staticmethod def get_list_command(target): if target == "tables":