Skip to content

Commit 8cbc09c

Browse files
authored
Create walinuxagent nftable atomically (#3239)
* Create walinuxagent nftable atomically * Update unit tests --------- Co-authored-by: narrieta@microsoft <narrieta>
1 parent e55a98a commit 8cbc09c

File tree

3 files changed

+27
-58
lines changed

3 files changed

+27
-58
lines changed

azurelinuxagent/ga/firewall_manager.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -408,13 +408,11 @@ def version(self):
408408
return self._version
409409

410410
def setup(self):
411-
shellutil.run_command(["nft", "add", "table", "ip", "walinuxagent"])
412-
shellutil.run_command(["nft", "add", "chain", "ip", "walinuxagent", "output", "{", "type", "filter", "hook", "output", "priority", "0", ";", "policy", "accept", ";", "}"])
413-
shellutil.run_command([
414-
"nft", "add", "rule", "ip", "walinuxagent", "output", "ip", "daddr", self._wire_server_address,
415-
"tcp", "dport", "!=", "53",
416-
"skuid", "!=", str(os.getuid()),
417-
"ct", "state", "invalid,new", "counter", "drop"])
411+
shellutil.run_command(["nft", "-f", "-"], input="""
412+
add table ip walinuxagent
413+
add chain ip walinuxagent output {{ type filter hook output priority 0 ; policy accept ; }}
414+
add rule ip walinuxagent output ip daddr {0} tcp dport != 53 skuid != {1} ct state invalid,new counter drop
415+
""".format(self._wire_server_address, os.getuid()))
418416

419417
def remove(self):
420418
shellutil.run_command(["nft", "delete", "table", "walinuxagent"])

tests/ga/test_firewall_manager.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,21 +217,20 @@ def test_setup_should_set_the_walinuxagent_table(self):
217217
firewall = NfTables('168.63.129.16')
218218
firewall.setup()
219219

220-
self.assertEqual(
221-
[
222-
mock_nft.get_add_command("table"),
223-
mock_nft.get_add_command("chain"),
224-
mock_nft.get_add_command("rule"),
225-
],
226-
mock_nft.call_list,
227-
"Expected exactly 3 calls, to the add the walinuxagent table, output chain, and wireserver rule")
220+
self.assertEqual(len(mock_nft.call_list), 1, "Expected exactly 1 call to execute a script to create the walinuxagent table; got {0}".format(mock_nft.call_list))
221+
222+
script = mock_nft.call_list[0]
223+
self.assertIn("add table ip walinuxagent", script, "The setup script should to create the walinuxagent table. Script: {0}".format(script))
224+
self.assertIn("add chain ip walinuxagent output", script, "The setup script should to create the output chain. Script: {0}".format(script))
225+
self.assertIn("add rule ip walinuxagent output ", script, "The setup script should to create the rule to manage the output chain. Script: {0}".format(script))
226+
228227

229228
def test_remove_should_delete_the_walinuxagent_table(self):
230229
with MockNft() as mock_nft:
231230
firewall = NfTables('168.63.129.16')
232231
firewall.remove()
233232

234-
self.assertEqual([mock_nft.get_delete_command()], mock_nft.call_list, "Expected a call to delete the walinuxagent table")
233+
self.assertEqual(['nft delete table walinuxagent'], mock_nft.call_list, "Expected a call to delete the walinuxagent table")
235234

236235
def test_check_should_verify_all_rules(self):
237236
with MockNft() as mock_nft:

tests/lib/mock_firewall_command.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,10 @@ def __init__(self):
242242
self._original_run_command = shellutil.run_command
243243
self._run_command_patcher = patch("azurelinuxagent.ga.firewall_manager.shellutil.run_command", side_effect=self._mock_run_command)
244244
#
245-
# Return values for each nft command-line indexed by command name ("add", "delete", "list"). Each item is a (exit_code, stdout) tuple.
246-
# These default values indicate success, and can be overridden with the set_*_return_values() methods.
245+
# Return values for the "delete" and "list" options of the nft command. Each item is a (exit_code, stdout) tuple.
246+
# The default values below indicate success, and can be overridden with the set_return_value() method.
247247
#
248248
self._return_values = {
249-
"add": {
250-
"table": (0, ''), # nft add table ip walinuxagent
251-
"chain": (0, ''), # nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; }
252-
"rule": (0, ''), # nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != 0 ct state invalid,new counter drop
253-
},
254249
"delete": {
255250
"table": (0, ''), # nft delete table walinuxagent
256251
},
@@ -300,15 +295,17 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
300295
def _mock_run_command(self, command, *args, **kwargs):
301296
if command[0] == 'nft':
302297
command_string = " ".join(command)
298+
if command_string == "nft --version":
299+
# return a hardcoded version string and don't add the command to the call list
300+
return self._original_run_command(['echo', 'nftables v1.0.2 (Lester Gooch)'], *args, **kwargs)
301+
elif command_string == 'nft -f -':
302+
# if we are executing an nft script, add the script to the call list and return success with no stdout (empty string)
303+
script = self._original_run_command(['cat'], *args, **kwargs)
304+
self._call_list.append(script)
305+
return self._original_run_command(['echo', '-n'], *args, **kwargs)
306+
# get the exit code and stdout from the pre-defined table of return values and add the command to the call list
303307
exit_code, stdout = self.get_return_value(command_string)
304-
script = \
305-
"""
306-
cat << ..
307-
{0}
308-
..
309-
exit {1}
310-
""".format(stdout, exit_code)
311-
command = ['sh', '-c', script]
308+
command = ['sh', '-c', "echo '{0}'; exit {1}".format(stdout, exit_code)]
312309
self._call_list.append(command_string)
313310
return self._original_run_command(command, *args, **kwargs)
314311

@@ -323,29 +320,18 @@ def set_return_value(self, command, target, return_value):
323320
"""
324321
Changes the return values for the mocked command
325322
"""
323+
if command not in self._return_values or target not in self._return_values[command]:
324+
raise Exception("Unexpected command: {0} {1}".format(command, target))
326325
self._return_values[command][target] = return_value
327326

328327
def get_return_value(self, command):
329328
"""
330329
Possible commands are:
331330
332-
nft add table ip walinuxagent
333-
nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; }
334-
nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != 0 ct state invalid,new counter drop
335331
nft delete table walinuxagent
336332
nft --json list tables
337333
nft --json list table walinuxagent
338334
"""
339-
r = r"nft add (?P<target>table|chain|rule)" + \
340-
r"(ip walinuxagent output " + \
341-
r"(\{ type filter hook output priority 0 ; policy accept ; })" + \
342-
r"|" + \
343-
r"(ip daddr 168.63.129.16 tcp dport != 53 skuid != \d+ ct state invalid,new counter drop)" + \
344-
r")?"
345-
match = re.match(r, command)
346-
if match is not None:
347-
target = match.group("target")
348-
return self._return_values["add"][target]
349335
if command == "nft delete table walinuxagent":
350336
return self._return_values["delete"]["table"]
351337
match = re.match(r"nft --json list (?P<target>tables|table)( walinuxagent)?", command)
@@ -354,20 +340,6 @@ def get_return_value(self, command):
354340
return self._return_values["list"][target]
355341
raise Exception("Unexpected command: {0}".format(command))
356342

357-
@staticmethod
358-
def get_add_command(target):
359-
if target == "table":
360-
return "nft add table ip walinuxagent"
361-
if target == "chain":
362-
return "nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; }"
363-
if target == "rule":
364-
return "nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != {0} ct state invalid,new counter drop".format(os.getuid())
365-
raise Exception("Unexpected command target: {0}".format(target))
366-
367-
@staticmethod
368-
def get_delete_command():
369-
return "nft delete table walinuxagent"
370-
371343
@staticmethod
372344
def get_list_command(target):
373345
if target == "tables":

0 commit comments

Comments
 (0)