Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions malsim/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
MalSimulator,
MalSimulatorSettings,
run_simulation,
load_scenario,
load_scenario
)
from .mal_simulator import TTCMode
from .mal_simulator import TTCMode, ITERATIONS_LIMIT

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -33,6 +33,10 @@ def main() -> None:
'-s', '--seed', type=int,
help="If set to a seed, simulator will use it as setting",
)
parser.add_argument(
'-m', '--max-iters', type=int, default=None,
help="Max number of steps in the simulation",
)
parser.add_argument(
'-t', '--ttc-mode', type=int,
help=(
Expand All @@ -48,8 +52,9 @@ def main() -> None:
scenario, MalSimulatorSettings(
seed=args.seed,
ttc_mode=TTCMode(args.ttc_mode),
attack_surface_skip_unnecessary=False
)
attack_surface_skip_unnecessary=False,
),
max_iter=args.max_iters or ITERATIONS_LIMIT
)

if args.output_attack_graph:
Expand Down
Empty file.
229 changes: 229 additions & 0 deletions malsim/agents/utils/greedy_a_star/algo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import heapq
import logging
from copy import copy
from itertools import count
from typing import Callable, Iterable, List, Tuple

from maltoolbox.attackgraph import AttackGraph, AttackGraphNode

from .utils import ttc_map, filter_defense, NoPath, merge_paths


def naive_a_star(
attack_graph: AttackGraph,
source: AttackGraphNode,
target: AttackGraphNode,
heuristic: Callable[[AttackGraphNode, AttackGraphNode], float] = lambda u, v: 0,
unusable_nodes: List[AttackGraphNode] = [],
) -> Tuple[List[AttackGraphNode], float]:
"""Returns a list of nodes in a heuristically shortest path between source and target using the A* ("A-star") algorithm.

Parameters
----------
attack_graph : AttackGraph

source : node
Starting node for path.

target : int
Target node for path.

heuristic : function
A function to evaluate the estimate of the distance from the a node to the target. The function takes two nodes arguments and must return a number.
The default heuristic is h=0 which is same as Dijkstra's algorithm.

Raises
------
ValueError
If no path exists between source and target.

Adapted from NetworkX and Sandor Berglund's thesis.
"""
if source.id not in attack_graph.nodes or target.id not in attack_graph.nodes:
raise ValueError(f"Either source {source} or target {target} is not in attack graph")

def cost(u, v):
return ttc_map(v.ttc) + 1

unusable_ids = {node.id for node in unusable_nodes}

# g_score[node.id] = the best-known cost from source to node
g_score = {}
# f_score[node.id] = g_score[node.id] + heuristic(node, target)
f_score = {}
# came_from[node.id] = the node we came from on the best path from source
came_from = {}

for node_id in attack_graph.nodes:
g_score[node_id] = float("inf")
f_score[node_id] = float("inf")
g_score[source.id] = 0
f_score[source.id] = heuristic(source, target)

# Priority queue of (f_score, tie-break counter, node)
# Tie-break counter ensures no direct comparison of nodes
queue = []
c = count()
heapq.heappush(queue, (f_score[source.id], next(c), source))

closed_set = set()

while queue:
_, __, current = heapq.heappop(queue)

# If we reached the target, reconstruct path
if current == target:
return _reconstruct_path(came_from, current), g_score[current.id]

if current.id in closed_set or current.id in unusable_ids:
continue
closed_set.add(current.id)

for neighbor in current.children:
if neighbor.id in closed_set or neighbor.id in unusable_ids:
continue
tentative_g = g_score[current.id] + cost(current, neighbor)
if tentative_g < g_score[neighbor.id]:
came_from[neighbor.id] = current
g_score[neighbor.id] = tentative_g
f_score[neighbor.id] = tentative_g + heuristic(neighbor, target)
heapq.heappush(queue, (f_score[neighbor.id], next(c), neighbor))

raise NoPath(target, f"Node {target.full_name} not reachable from {source.full_name}")

def _reconstruct_path(came_from, current):
"""Reconstructs path using came_from after we pop target from the queue."""
path = [current]
while current.id in came_from:
current = came_from[current.id]
path.append(current)
path.reverse()
return path



def correct_and_steps(
attack_graph: AttackGraph,
source: AttackGraphNode,
naive_path: List[AttackGraphNode],
unused_entry_points: List[AttackGraphNode],
unusable_nodes: List[AttackGraphNode],
):
new_path = naive_path

def node_collection_difference(
a: Iterable[AttackGraphNode], b: Iterable[AttackGraphNode]
) -> List[AttackGraphNode]:
ids = set(node.id for node in a).difference(node.id for node in b)
return [node for node in a if node.id in ids]

for i, node in enumerate(new_path):
if node.type == "and" and node.id != source.id:
# For each AND-step parent, not in the path
for and_parent in node_collection_difference(
filter_defense(node.parents), new_path[:i]
):
min_path, min_ttc = None, float("inf")
for sub_source in new_path[:i] + unused_entry_points:
try:
p, c = naive_a_star(attack_graph, sub_source, and_parent, unusable_nodes=unusable_nodes)
if c < min_ttc:
min_path, min_ttc = p, c
except NoPath:
pass
if not min_path:
raise NoPath(
node, f"No path to AND-step parent: {and_parent.full_name} for {node.full_name}"
)
new_path = merge_paths(new_path, min_path, i)

return new_path


def single_source_a_star(
attack_graph: AttackGraph,
source: AttackGraphNode,
target: AttackGraphNode,
other_sources: List[AttackGraphNode],
) -> List[AttackGraphNode]:
path, _ = naive_a_star(attack_graph, source, target)
unusable_nodes = []

def check_and_steps(path: List[AttackGraphNode]) -> bool:
and_steps = [
node for node in path
if node.type == "and"
and node.id != source.id
and node.id not in {
node.id for node in other_sources
}
]
return all(
parent in path
for node in and_steps
for parent in filter_defense(node.parents)
)

while not check_and_steps(other_sources + path):
try:
path = correct_and_steps(
attack_graph, source, path, other_sources, unusable_nodes
)
except NoPath as unreachable:
logging.debug(
unreachable.args[0]
)
unusable_nodes.append(unreachable.node)
old_path = copy(path)
path, _ = naive_a_star(
attack_graph, source, target, unusable_nodes=unusable_nodes
)
if path == old_path:
raise unreachable
return path


def single_target_a_star(
attack_graph: AttackGraph,
sources: AttackGraphNode | List[AttackGraphNode],
target: AttackGraphNode,
) -> List[AttackGraphNode]:
if isinstance(sources, AttackGraphNode):
sources = [sources]

ret_path, min_ttc = [], float("inf")

for i, source in enumerate(sources):
try:
path = single_source_a_star(
attack_graph, source, target, sources[:i] + sources[i + 1 :]
)
path_ttc = sum(ttc_map(node.ttc) for node in path)
if path_ttc < min_ttc:
ret_path, min_ttc = path, path_ttc
except NoPath:
pass

if len(ret_path) == 0:
raise NoPath(target, f"No path to target: {target}")

return ret_path


def greedy_a_star_attack(
attack_graph: AttackGraph,
sources: AttackGraphNode | List[AttackGraphNode],
targets: AttackGraphNode | List[AttackGraphNode],
) -> List[AttackGraphNode]:
if isinstance(sources, AttackGraphNode):
sources = [sources]
if isinstance(targets, AttackGraphNode):
targets = [targets]

ret_path = []

for target in targets:
new_path = single_target_a_star(attack_graph, sources, target)
ret_path = merge_paths(ret_path, new_path, len(ret_path))

return ret_path
65 changes: 65 additions & 0 deletions malsim/agents/utils/greedy_a_star/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Collection, Dict, Iterable, List, Optional
from maltoolbox.attackgraph import AttackGraphNode, AttackGraph
from maltoolbox.model import Model
import networkx as nx

def ttc_map(ttc: Optional[Dict]):
if not ttc:
return 0
elif ttc["name"] == "VeryHardAndUncertain":
return 50
elif ttc["name"] == "VeryHardAndCertain":
return 25
elif ttc["name"] == "HardAndUncertain":
return 5
elif ttc["name"] == "EasyAndCertain":
return 1
elif ttc["name"] == "Exponential":
return 1 / ttc["arguments"][0]

def filter_defense(c: Iterable[AttackGraphNode]) -> List[AttackGraphNode]:
return [node for node in c if node.type != "defense"]


class NoPath(Exception):
def __init__(self, node: AttackGraphNode, *args: object) -> None:
super().__init__(*args)
self.node = node

def merge_paths(
a: List[AttackGraphNode], b: List[AttackGraphNode], index: int
) -> List[AttackGraphNode]:
path_extension = [node for node in b if node not in a]
return a[:index] + path_extension + a[index:]

def to_nx(nodes: AttackGraph | Iterable[AttackGraphNode]):
if isinstance(nodes, AttackGraph):
nodes = list(nodes.nodes.values())
G = nx.DiGraph()

for node in nodes:
G.add_node(node.id, **node.to_dict())
G.nodes[node.id]["full_name"] = node.full_name

edges = [(node.id, child.id) for node in nodes for child in node.children]
edges += [(parent.id, node.id) for node in nodes for parent in node.parents]
G.add_edges_from(edges)

return G


def model_to_nx(model: Model) -> nx.Graph:
d = model._to_dict()
assets = d["assets"]

G = nx.Graph()

for id, vals in assets.items():
G.add_node(id, **vals)

for id, vals in assets.items():
neighbour_ids = [id for d in vals["associated_assets"].values() for id in d.keys() ]
for neighbour_id in neighbour_ids:
G.add_edge(id, neighbour_id)

return G
4 changes: 2 additions & 2 deletions malsim/mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def _attacker_step(

else:
logger.warning(
"Attacker could not compromise %s", node.full_name
"Attacker could not compromise untraversable %s", node.full_name
)

return successful_compromises, attempted_compromises
Expand Down Expand Up @@ -1286,7 +1286,7 @@ def run_simulation(
actions[agent_name] = [agent_action]
print(
f'Agent {agent_name} chose action: '
f'{agent_action.full_name}'
f'{agent_action.full_name} ({agent_action.type})'
)

# Store agent action
Expand Down
2 changes: 1 addition & 1 deletion malsim/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def load_simulator_agents(
if agent_class_name and agent_class_name not in agent_class_name_to_class:
raise LookupError(
f"Agent class '{agent_class_name}' not supported.\n"
f"Must be one of: {agent_class_name_to_class.values()}"
f"Must be one of: {list(agent_class_name_to_class.keys())}"
)

if agent_type == AgentType.ATTACKER:
Expand Down
Loading
Loading