Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions malsim/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def _validate_scenario_property_rules(

# a way to lookup attack steps for asset types
asset_type_step_names = {
asset_type.name: [a.name for a in asset_type.attack_steps]
for asset_type in graph.lang_graph.assets
asset_type_name: [step_name for step_name in asset_type.attack_steps]
for asset_type_name, asset_type in graph.lang_graph.assets.items()
}

if rules is None:
Expand Down
109 changes: 45 additions & 64 deletions malsim/sims/mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,20 @@ def create_blank_observation(self, default_obs_state=-1):
right_field = getattr(assoc, right_field_name)
for left_asset in left_field:
for right_asset in right_field:

observation["model_edges_ids"].append(
[
self._model_asset_id_to_index[left_asset.id],
self._model_asset_id_to_index[right_asset.id]
]
)

observation["model_edges_type"].append(
self._model_assoc_type_to_index[
self._get_association_full_name(assoc)])
# TODO: change this when lang_class factory not used anymore
assoc.__class__.__name__.removeprefix('Association_')
]
)


np_obs = {
Expand Down Expand Up @@ -336,11 +341,23 @@ def _format_info(self, info):
@functools.lru_cache(maxsize=None)
def observation_space(self, agent=None):
# For now, an `object` is an attack step
num_assets = len(self.model.assets)
num_assets = len(self.attack_graph.model.assets)
num_steps = len(self.attack_graph.nodes)
num_lang_asset_types = len(self.lang_graph.assets)
num_lang_attack_steps = len(self.lang_graph.attack_steps)
num_lang_association_types = len(self.lang_graph.associations)

unique_step_types = set()
for asset_type in self.lang_graph.assets.values():
unique_step_types |= set(asset_type.attack_steps.values())
num_lang_attack_steps = len(unique_step_types)

unique_assoc_type_names = set()
for asset_type in self.lang_graph.assets.values():
for assoc_type in asset_type.associations.values():
unique_assoc_type_names.add(
assoc_type.full_name
)
num_lang_association_types = len(unique_assoc_type_names)

num_attack_graph_edges = len(
self._blank_observation["attack_graph_edges"])
num_model_edges = len(
Expand Down Expand Up @@ -521,17 +538,30 @@ def _create_mapping_tables(self):

# Lookup lists index to attribute
self._index_to_id = [n.id for n in self.attack_graph.nodes]
self._index_to_full_name = [n.full_name
for n in self.attack_graph.nodes]
self._index_to_asset_type = [n.name for n in self.lang_graph.assets]
self._index_to_step_name = [n.asset.name + ":" + n.name
for n in self.lang_graph.attack_steps]
self._index_to_model_asset_id = [int(asset.id) for asset in \
self.attack_graph.model.assets]
self._index_to_model_assoc_type = [assoc.name + '_' + \
assoc.left_field.asset.name + '_' + \
assoc.right_field.asset.name \
for assoc in self.lang_graph.associations]
self._index_to_full_name = (
[n.full_name for n in self.attack_graph.nodes]
)
self._index_to_asset_type = (
[n.name for n in self.lang_graph.assets.values()]
)

unique_step_type_names = {
n.full_name
for asset in self.lang_graph.assets.values()
for n in asset.attack_steps.values()
}
self._index_to_step_name = list(unique_step_type_names)

self._index_to_model_asset_id = (
[int(asset.id) for asset in self.attack_graph.model.assets]
)

unique_assoc_type_names = {
assoc.full_name
for asset in self.lang_graph.assets.values()
for assoc in asset.associations.values()
}
self._index_to_model_assoc_type = list(unique_assoc_type_names)

# Lookup dicts attribute to index
self._id_to_index = {
Comment on lines 540 to 567
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of these will be possible to drop once mal-lang/mal-toolbox#105 is merged.

Expand Down Expand Up @@ -604,55 +634,6 @@ def action_to_node(
node = self.index_to_node(step_idx)
return node

def _get_association_full_name(self, association) -> str:
"""Get association full name

TODO: Remove this method once the language graph integration is
complete in the mal-toolbox because the language graph associations
will use their full names for the name property

Arguments:
association - the association whose full name will be returned

Return:
A string containing the association name and the name of each of the
two asset types for the left and right fields separated by
underscores.
"""

assoc_name = association.__class__.__name__
if '_' in assoc_name:
# TODO: Not actually a to-do, but just an extra clarification that
# this is an ugly hack that will work for now until we get the
# unique association names. Right now some associations already
# use the asset types as part of their name if there are multiple
# associations with the same name.
return assoc_name

left_field_name, right_field_name = \
self.model.get_association_field_names(association)
left_field = getattr(association, left_field_name)
right_field = getattr(association, right_field_name)
lang_assoc = self.lang_graph.get_association_by_fields_and_assets(
left_field_name,
right_field_name,
left_field[0].type,
right_field[0].type
)
if lang_assoc is None:
raise LookupError('Failed to find association for fields '
'"%s" "%s" and asset types "%s" "%s"!' % (
left_field_name,
right_field_name,
left_field[0].type,
right_field[0].type
)
)
assoc_full_name = lang_assoc.name + '_' + \
lang_assoc.left_field.asset.name + '_' + \
lang_assoc.right_field.asset.name
return assoc_full_name


def _initialize_agents(self) -> dict[str, list[int]]:
"""Initialize agent rewards, observations, and action surfaces
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ requires-python = ">=3.10"
dependencies = [
"py2neo>=2021.2.3",
"python-jsonschema-objects>=0.4.1",
"mal-toolbox==0.1.12",
"mal-toolbox==0.2",
"numpy>=1.21.4",
"pettingzoo>=1.24.2",
"gymnasium==1.0",
Expand Down
97 changes: 72 additions & 25 deletions tests/test_example_scenarios.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,48 @@
from malsim.scenario import create_simulator_from_scenario


def test_bfs_vs_bfs_state_and_reward():
sim, sim_config = create_simulator_from_scenario(
'tests/testdata/scenarios/bfs_vs_bfs_scenario.yml')
"tests/testdata/scenarios/bfs_vs_bfs_scenario.yml"
)
obs, infos = sim.reset()

attacker_agent_id = next(iter(sim.get_attacker_agents()))
defender_agent_id = next(iter(sim.get_defender_agents()), None)

# Initialize defender and attacker according to classes
defender_class = sim_config['agents'][defender_agent_id]['agent_class']
defender_class = sim_config["agents"][defender_agent_id]["agent_class"]
defender_agent = defender_class({})

attacker_class = sim_config['agents'][attacker_agent_id]['agent_class']
attacker_class = sim_config["agents"][attacker_agent_id]["agent_class"]
attacker_agent = attacker_class({})

total_reward_defender = 0
total_reward_attacker = 0

attacker = sim.attack_graph.attackers[0]
attacker_actions = [sim._id_to_index[n.id] for n in attacker.entry_points]
defender_actions = [sim._id_to_index[n.id]
for n in sim.attack_graph.nodes
if n.is_enabled_defense()]
attacker_actions = [n.full_name for n in attacker.entry_points]
defender_actions = [
n.full_name for n in sim.attack_graph.nodes if n.is_enabled_defense()
]

while True:
defender_action = defender_agent.compute_action_from_dict(
obs[defender_agent_id],
infos[defender_agent_id]["action_mask"])
obs[defender_agent_id], infos[defender_agent_id]["action_mask"]
)

attacker_action = attacker_agent.compute_action_from_dict(
obs[attacker_agent_id],
infos[attacker_agent_id]["action_mask"])
obs[attacker_agent_id], infos[attacker_agent_id]["action_mask"]
)

if attacker_action[0]:
attacker_actions.append(int(attacker_action[1]))
attacker_node = sim.action_to_node(attacker_action)
attacker_actions.append(attacker_node.full_name)
if defender_action[0]:
defender_actions.append(int(defender_action[1]))
defender_node = sim.action_to_node(defender_action)
defender_actions.append(defender_node.full_name)

actions = {
'defender': defender_action,
'attacker': attacker_action
}
actions = {"defender": defender_action, "attacker": attacker_action}
obs, rewards, terminated, truncated, infos = sim.step(actions)

total_reward_defender += rewards.get(defender_agent_id, 0)
Expand All @@ -50,16 +51,62 @@ def test_bfs_vs_bfs_state_and_reward():
if terminated[defender_agent_id] or terminated[attacker_agent_id]:
break

assert attacker_actions == [328, 329, 353, 330, 354, 355, 356, 331, 357, 283, 332, 375, 358, 376, 377]
assert defender_actions == [68, 249, 324, 325, 349, 350, 396, 397, 421, 422, 423, 457, 0, 31, 88, 113, 144, 181, 212, 252, 276, 326, 327, 351, 352, 374]

for step_index in attacker_actions:
node = sim.attack_graph.get_node_by_id(sim._index_to_id[step_index])
assert attacker_actions == [
"Credentials:6:attemptCredentialsReuse",
"Credentials:6:credentialsReuse",
"Credentials:7:attemptCredentialsReuse",
"Credentials:6:attemptUse",
"Credentials:7:credentialsReuse",
"Credentials:7:attemptUse",
"Credentials:7:use",
"Credentials:7:attemptPropagateOneCredentialCompromised",
"Credentials:7:propagateOneCredentialCompromised",
"User:12:oneCredentialCompromised",
"User:12:passwordReuseCompromise",
"Credentials:9:attemptCredentialsReuse",
"Credentials:10:attemptCredentialsReuse",
"Credentials:9:credentialsReuse",
"Credentials:9:attemptUse",
]
assert defender_actions == [
"Program 1:notPresent",
"IDPS 1:effectiveness",
"Credentials:6:notDisclosed",
"Credentials:6:notGuessable",
"Credentials:7:notDisclosed",
"Credentials:7:notGuessable",
"Credentials:9:notDisclosed",
"Credentials:9:notGuessable",
"Credentials:10:notDisclosed",
"Credentials:10:notGuessable",
"Credentials:10:unique",
"User:12:noRemovableMediaUsage",
"OS App:notPresent",
"OS App:supplyChainAuditing",
"Program 1:supplyChainAuditing",
"Program 2:notPresent",
"Program 2:supplyChainAuditing",
"IDPS 1:notPresent",
"IDPS 1:supplyChainAuditing",
"SoftwareVulnerability:4:notPresent",
"Data:5:notPresent",
"Credentials:6:unique",
"Credentials:6:notPhishable",
"Credentials:7:unique",
"Credentials:7:notPhishable",
"Identity:8:notPresent",
]

for step_fullname in attacker_actions:
node = sim.attack_graph.get_node_by_full_name(step_fullname)
if node.is_compromised():
assert obs[defender_agent_id]['observed_state'][step_index]
node_index = sim.node_to_index(node)
assert obs[defender_agent_id]["observed_state"][node_index]

for step_index in defender_actions:
assert obs[defender_agent_id]['observed_state'][step_index]
for step_fullname in defender_actions:
node = sim.attack_graph.get_node_by_full_name(step_fullname)
node_index = sim.node_to_index(node)
assert obs[defender_agent_id]["observed_state"][node_index]

assert rewards[attacker_agent_id] == 0
assert rewards[defender_agent_id] == -31
Expand Down
Loading
Loading