diff --git a/malsim/scenario.py b/malsim/scenario.py index 0dfeb364..de49e4f4 100644 --- a/malsim/scenario.py +++ b/malsim/scenario.py @@ -97,8 +97,8 @@ def _validate_scenario_property_rules( # a way to lookup attack steps for asset types asset_type_step_names = { - asset_type.name: [a.name for a in asset_type.attack_steps] - for asset_type in graph.lang_graph.assets + asset_type_name: [step_name for step_name in asset_type.attack_steps] + for asset_type_name, asset_type in graph.lang_graph.assets.items() } if rules is None: diff --git a/malsim/sims/mal_simulator.py b/malsim/sims/mal_simulator.py index ba11d388..aec099c9 100644 --- a/malsim/sims/mal_simulator.py +++ b/malsim/sims/mal_simulator.py @@ -159,15 +159,20 @@ def create_blank_observation(self, default_obs_state=-1): right_field = getattr(assoc, right_field_name) for left_asset in left_field: for right_asset in right_field: + observation["model_edges_ids"].append( [ self._model_asset_id_to_index[left_asset.id], self._model_asset_id_to_index[right_asset.id] ] ) + observation["model_edges_type"].append( self._model_assoc_type_to_index[ - self._get_association_full_name(assoc)]) + # TODO: change this when lang_class factory not used anymore + assoc.__class__.__name__.removeprefix('Association_') + ] + ) np_obs = { @@ -336,11 +341,23 @@ def _format_info(self, info): @functools.lru_cache(maxsize=None) def observation_space(self, agent=None): # For now, an `object` is an attack step - num_assets = len(self.model.assets) + num_assets = len(self.attack_graph.model.assets) num_steps = len(self.attack_graph.nodes) num_lang_asset_types = len(self.lang_graph.assets) - num_lang_attack_steps = len(self.lang_graph.attack_steps) - num_lang_association_types = len(self.lang_graph.associations) + + unique_step_types = set() + for asset_type in self.lang_graph.assets.values(): + unique_step_types |= set(asset_type.attack_steps.values()) + num_lang_attack_steps = len(unique_step_types) + + unique_assoc_type_names = set() + for asset_type in self.lang_graph.assets.values(): + for assoc_type in asset_type.associations.values(): + unique_assoc_type_names.add( + assoc_type.full_name + ) + num_lang_association_types = len(unique_assoc_type_names) + num_attack_graph_edges = len( self._blank_observation["attack_graph_edges"]) num_model_edges = len( @@ -521,17 +538,30 @@ def _create_mapping_tables(self): # Lookup lists index to attribute self._index_to_id = [n.id for n in self.attack_graph.nodes] - self._index_to_full_name = [n.full_name - for n in self.attack_graph.nodes] - self._index_to_asset_type = [n.name for n in self.lang_graph.assets] - self._index_to_step_name = [n.asset.name + ":" + n.name - for n in self.lang_graph.attack_steps] - self._index_to_model_asset_id = [int(asset.id) for asset in \ - self.attack_graph.model.assets] - self._index_to_model_assoc_type = [assoc.name + '_' + \ - assoc.left_field.asset.name + '_' + \ - assoc.right_field.asset.name \ - for assoc in self.lang_graph.associations] + self._index_to_full_name = ( + [n.full_name for n in self.attack_graph.nodes] + ) + self._index_to_asset_type = ( + [n.name for n in self.lang_graph.assets.values()] + ) + + unique_step_type_names = { + n.full_name + for asset in self.lang_graph.assets.values() + for n in asset.attack_steps.values() + } + self._index_to_step_name = list(unique_step_type_names) + + self._index_to_model_asset_id = ( + [int(asset.id) for asset in self.attack_graph.model.assets] + ) + + unique_assoc_type_names = { + assoc.full_name + for asset in self.lang_graph.assets.values() + for assoc in asset.associations.values() + } + self._index_to_model_assoc_type = list(unique_assoc_type_names) # Lookup dicts attribute to index self._id_to_index = { @@ -604,55 +634,6 @@ def action_to_node( node = self.index_to_node(step_idx) return node - def _get_association_full_name(self, association) -> str: - """Get association full name - - TODO: Remove this method once the language graph integration is - complete in the mal-toolbox because the language graph associations - will use their full names for the name property - - Arguments: - association - the association whose full name will be returned - - Return: - A string containing the association name and the name of each of the - two asset types for the left and right fields separated by - underscores. - """ - - assoc_name = association.__class__.__name__ - if '_' in assoc_name: - # TODO: Not actually a to-do, but just an extra clarification that - # this is an ugly hack that will work for now until we get the - # unique association names. Right now some associations already - # use the asset types as part of their name if there are multiple - # associations with the same name. - return assoc_name - - left_field_name, right_field_name = \ - self.model.get_association_field_names(association) - left_field = getattr(association, left_field_name) - right_field = getattr(association, right_field_name) - lang_assoc = self.lang_graph.get_association_by_fields_and_assets( - left_field_name, - right_field_name, - left_field[0].type, - right_field[0].type - ) - if lang_assoc is None: - raise LookupError('Failed to find association for fields ' - '"%s" "%s" and asset types "%s" "%s"!' % ( - left_field_name, - right_field_name, - left_field[0].type, - right_field[0].type - ) - ) - assoc_full_name = lang_assoc.name + '_' + \ - lang_assoc.left_field.asset.name + '_' + \ - lang_assoc.right_field.asset.name - return assoc_full_name - def _initialize_agents(self) -> dict[str, list[int]]: """Initialize agent rewards, observations, and action surfaces diff --git a/pyproject.toml b/pyproject.toml index 2717c8e6..5d73b676 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ requires-python = ">=3.10" dependencies = [ "py2neo>=2021.2.3", "python-jsonschema-objects>=0.4.1", - "mal-toolbox==0.1.12", + "mal-toolbox==0.2", "numpy>=1.21.4", "pettingzoo>=1.24.2", "gymnasium==1.0", diff --git a/tests/test_example_scenarios.py b/tests/test_example_scenarios.py index 081af4dd..95461ac6 100644 --- a/tests/test_example_scenarios.py +++ b/tests/test_example_scenarios.py @@ -1,47 +1,48 @@ from malsim.scenario import create_simulator_from_scenario + def test_bfs_vs_bfs_state_and_reward(): sim, sim_config = create_simulator_from_scenario( - 'tests/testdata/scenarios/bfs_vs_bfs_scenario.yml') + "tests/testdata/scenarios/bfs_vs_bfs_scenario.yml" + ) obs, infos = sim.reset() attacker_agent_id = next(iter(sim.get_attacker_agents())) defender_agent_id = next(iter(sim.get_defender_agents()), None) # Initialize defender and attacker according to classes - defender_class = sim_config['agents'][defender_agent_id]['agent_class'] + defender_class = sim_config["agents"][defender_agent_id]["agent_class"] defender_agent = defender_class({}) - attacker_class = sim_config['agents'][attacker_agent_id]['agent_class'] + attacker_class = sim_config["agents"][attacker_agent_id]["agent_class"] attacker_agent = attacker_class({}) total_reward_defender = 0 total_reward_attacker = 0 attacker = sim.attack_graph.attackers[0] - attacker_actions = [sim._id_to_index[n.id] for n in attacker.entry_points] - defender_actions = [sim._id_to_index[n.id] - for n in sim.attack_graph.nodes - if n.is_enabled_defense()] + attacker_actions = [n.full_name for n in attacker.entry_points] + defender_actions = [ + n.full_name for n in sim.attack_graph.nodes if n.is_enabled_defense() + ] while True: defender_action = defender_agent.compute_action_from_dict( - obs[defender_agent_id], - infos[defender_agent_id]["action_mask"]) + obs[defender_agent_id], infos[defender_agent_id]["action_mask"] + ) attacker_action = attacker_agent.compute_action_from_dict( - obs[attacker_agent_id], - infos[attacker_agent_id]["action_mask"]) + obs[attacker_agent_id], infos[attacker_agent_id]["action_mask"] + ) if attacker_action[0]: - attacker_actions.append(int(attacker_action[1])) + attacker_node = sim.action_to_node(attacker_action) + attacker_actions.append(attacker_node.full_name) if defender_action[0]: - defender_actions.append(int(defender_action[1])) + defender_node = sim.action_to_node(defender_action) + defender_actions.append(defender_node.full_name) - actions = { - 'defender': defender_action, - 'attacker': attacker_action - } + actions = {"defender": defender_action, "attacker": attacker_action} obs, rewards, terminated, truncated, infos = sim.step(actions) total_reward_defender += rewards.get(defender_agent_id, 0) @@ -50,16 +51,62 @@ def test_bfs_vs_bfs_state_and_reward(): if terminated[defender_agent_id] or terminated[attacker_agent_id]: break - assert attacker_actions == [328, 329, 353, 330, 354, 355, 356, 331, 357, 283, 332, 375, 358, 376, 377] - assert defender_actions == [68, 249, 324, 325, 349, 350, 396, 397, 421, 422, 423, 457, 0, 31, 88, 113, 144, 181, 212, 252, 276, 326, 327, 351, 352, 374] - - for step_index in attacker_actions: - node = sim.attack_graph.get_node_by_id(sim._index_to_id[step_index]) + assert attacker_actions == [ + "Credentials:6:attemptCredentialsReuse", + "Credentials:6:credentialsReuse", + "Credentials:7:attemptCredentialsReuse", + "Credentials:6:attemptUse", + "Credentials:7:credentialsReuse", + "Credentials:7:attemptUse", + "Credentials:7:use", + "Credentials:7:attemptPropagateOneCredentialCompromised", + "Credentials:7:propagateOneCredentialCompromised", + "User:12:oneCredentialCompromised", + "User:12:passwordReuseCompromise", + "Credentials:9:attemptCredentialsReuse", + "Credentials:10:attemptCredentialsReuse", + "Credentials:9:credentialsReuse", + "Credentials:9:attemptUse", + ] + assert defender_actions == [ + "Program 1:notPresent", + "IDPS 1:effectiveness", + "Credentials:6:notDisclosed", + "Credentials:6:notGuessable", + "Credentials:7:notDisclosed", + "Credentials:7:notGuessable", + "Credentials:9:notDisclosed", + "Credentials:9:notGuessable", + "Credentials:10:notDisclosed", + "Credentials:10:notGuessable", + "Credentials:10:unique", + "User:12:noRemovableMediaUsage", + "OS App:notPresent", + "OS App:supplyChainAuditing", + "Program 1:supplyChainAuditing", + "Program 2:notPresent", + "Program 2:supplyChainAuditing", + "IDPS 1:notPresent", + "IDPS 1:supplyChainAuditing", + "SoftwareVulnerability:4:notPresent", + "Data:5:notPresent", + "Credentials:6:unique", + "Credentials:6:notPhishable", + "Credentials:7:unique", + "Credentials:7:notPhishable", + "Identity:8:notPresent", + ] + + for step_fullname in attacker_actions: + node = sim.attack_graph.get_node_by_full_name(step_fullname) if node.is_compromised(): - assert obs[defender_agent_id]['observed_state'][step_index] + node_index = sim.node_to_index(node) + assert obs[defender_agent_id]["observed_state"][node_index] - for step_index in defender_actions: - assert obs[defender_agent_id]['observed_state'][step_index] + for step_fullname in defender_actions: + node = sim.attack_graph.get_node_by_full_name(step_fullname) + node_index = sim.node_to_index(node) + assert obs[defender_agent_id]["observed_state"][node_index] assert rewards[attacker_agent_id] == 0 assert rewards[defender_agent_id] == -31 diff --git a/tests/testdata/models/run_demo_model.json b/tests/testdata/models/run_demo_model.json index 0d1d3f46..79737485 100644 --- a/tests/testdata/models/run_demo_model.json +++ b/tests/testdata/models/run_demo_model.json @@ -1,119 +1,120 @@ { - "metadata": { - "name": "Example Model", - "langVersion": "1.0.0", - "langID": "org.mal-lang.coreLang", - "malVersion": "0.1.0-SNAPSHOT", - "MAL-Toolbox Version": "0.1.6", - "info": "Created by the mal-toolbox model python module." - }, - "assets": { - "0": { - "name": "OS App", - "type": "Application" - }, - "1": { - "name": "Program 1", - "type": "Application" - }, - "2": { - "name": "SoftwareVulnerability:2", - "type": "SoftwareVulnerability" - }, - "3": { - "name": "Data:3", - "type": "Data" - }, - "4": { - "name": "Credentials:4", - "type": "Credentials" - }, - "5": { - "name": "Identity:5", - "type": "Identity" - }, - "6": { - "name": "ConnectionRule:6", - "type": "ConnectionRule" - }, - "7": { - "name": "Other OS App", - "type": "Application" - } - }, - "associations": [ - { - "AppExecution": { - "hostApp": [ - 0 - ], - "appExecutedApps": [ - 1 - ] - } - }, - { - "ApplicationVulnerability_SoftwareVulnerability_Application": { - "vulnerabilities": [ - 2 - ], - "application": [ - 0 - ] - } - }, - { - "AppContainment": { - "containedData": [ - 3 - ], - "containingApp": [ - 1 - ] - } - }, - { - "IdentityCredentials": { - "identities": [ - 5 - ], - "credentials": [ - 4 - ] - } - }, - { - "InfoContainment": { - "containerData": [ - 3 - ], - "information": [ - 4 - ] - } - }, - { - "ApplicationConnection": { - "applications": [ - 0, - 7 - ], - "appConnections": [ - 6 - ] - } - } - ], - "attackers": { - "8": { - "name": "Attacker:8", - "entry_points": { - "0": { - "attack_steps": [ - "networkConnectUninspected" - ] - } - } - } - } + "metadata": { + "name": "Example Model", + "langVersion": "1.0.0", + "langID": "org.mal-lang.coreLang", + "malVersion": "0.1.0-SNAPSHOT", + "MAL-Toolbox Version": "0.2.0", + "info": "Created by the mal-toolbox model python module." + }, + "assets": { + "0": { + "name": "OS App", + "type": "Application" + }, + "1": { + "name": "Program 1", + "type": "Application" + }, + "2": { + "name": "SoftwareVulnerability:2", + "type": "SoftwareVulnerability" + }, + "3": { + "name": "Data:3", + "type": "Data" + }, + "4": { + "name": "Credentials:4", + "type": "Credentials" + }, + "5": { + "name": "Identity:5", + "type": "Identity" + }, + "6": { + "name": "ConnectionRule:6", + "type": "ConnectionRule" + }, + "7": { + "name": "Other OS App", + "type": "Application" + } + }, + "associations": [ + { + "AppExecution": { + "hostApp": { + "0": "OS App" + }, + "appExecutedApps": { + "1": "Program 1" + } + } + }, + { + "ApplicationVulnerability": { + "vulnerabilities": { + "2": "SoftwareVulnerability:2" + }, + "application": { + "0": "OS App" + } + } + }, + { + "AppContainment": { + "containedData": { + "3": "Data:3" + }, + "containingApp": { + "1": "Program 1" + } + } + }, + { + "IdentityCredentials": { + "identities": { + "5": "Identity:5" + }, + "credentials": { + "4": "Credentials:4" + } + } + }, + { + "InfoContainment": { + "containerData": { + "3": "Data:3" + }, + "information": { + "4": "Credentials:4" + } + } + }, + { + "ApplicationConnection": { + "applications": { + "0": "OS App", + "7": "Other OS App" + }, + "appConnections": { + "6": "ConnectionRule:6" + } + } + } + ], + "attackers": { + "8": { + "name": "Attacker:8", + "entry_points": { + "OS App": { + "asset_id": 0, + "attack_steps": [ + "networkConnectUninspected" + ] + } + } + } + } } \ No newline at end of file diff --git a/tests/testdata/models/simple_no_attacker_test_model.yml b/tests/testdata/models/simple_no_attacker_test_model.yml index 3d9dc357..c6c473d9 100644 --- a/tests/testdata/models/simple_no_attacker_test_model.yml +++ b/tests/testdata/models/simple_no_attacker_test_model.yml @@ -57,56 +57,56 @@ assets: associations: - AppExecution: appExecutedApps: - - 1 - - 2 - - 3 + 1: Program 1 + 2: Program 2 + 3: IDPS 1 hostApp: - - 0 -- ApplicationVulnerability_SoftwareVulnerability_Application: + 0: OS App +- ApplicationVulnerability: application: - - 2 + 2: Program 2 vulnerabilities: - - 4 + 4: SoftwareVulnerability:4 - AppContainment: containedData: - - 5 + 5: Data:5 containingApp: - - 2 + 2: Program 2 - EncryptionCredentials: encryptCreds: - - 6 + 6: Credentials:6 encryptedData: - - 5 + 5: Data:5 - ConditionalAuthentication: credentials: - - 6 + 6: Credentials:6 requiredFactors: - - 7 + 7: Credentials:7 - IdentityCredentials: credentials: - - 6 + 6: Credentials:6 identities: - - 8 + 8: Identity:8 - IdentityCredentials: credentials: - - 9 - - 10 + 9: Credentials:9 + 10: Credentials:10 identities: - - 11 + 11: Identity:11 - UserAssignedIdentities: userIds: - - 8 - - 11 + 8: Identity:8 + 11: Identity:11 users: - - 12 -- Dependence_Information_Application: + 12: User:12 +- Dependence: dependentApps: - - 3 + 3: IDPS 1 infoDependedUpon: - - 13 + 13: Group:13 attackers: {} metadata: - MAL-Toolbox Version: 0.1.8 + MAL-Toolbox Version: 0.2.0 info: Created by the mal-toolbox model python module. langID: org.mal-lang.coreLang langVersion: 1.0.0 diff --git a/tests/testdata/models/simple_test_model.yml b/tests/testdata/models/simple_test_model.yml index 814ee9b0..ae52de33 100644 --- a/tests/testdata/models/simple_test_model.yml +++ b/tests/testdata/models/simple_test_model.yml @@ -57,62 +57,63 @@ assets: associations: - AppExecution: appExecutedApps: - - 1 - - 2 - - 3 + 1: Program 1 + 2: Program 2 + 3: IDPS 1 hostApp: - - 0 -- ApplicationVulnerability_SoftwareVulnerability_Application: + 0: OS App +- ApplicationVulnerability: application: - - 2 + 2: Program 2 vulnerabilities: - - 4 + 4: SoftwareVulnerability:4 - AppContainment: containedData: - - 5 + 5: Data:5 containingApp: - - 2 + 2: Program 2 - EncryptionCredentials: encryptCreds: - - 6 + 6: Credentials:6 encryptedData: - - 5 + 5: Data:5 - ConditionalAuthentication: credentials: - - 6 + 6: Credentials:6 requiredFactors: - - 7 + 7: Credentials:7 - IdentityCredentials: credentials: - - 6 + 6: Credentials:6 identities: - - 8 + 8: Identity:8 - IdentityCredentials: credentials: - - 9 - - 10 + 9: Credentials:9 + 10: Credentials:10 identities: - - 11 + 11: Identity:11 - UserAssignedIdentities: userIds: - - 8 - - 11 + 8: Identity:8 + 11: Identity:11 users: - - 12 -- Dependence_Information_Application: + 12: User:12 +- Dependence: dependentApps: - - 3 + 3: IDPS 1 infoDependedUpon: - - 13 + 13: Group:13 attackers: 15: entry_points: - 0: + OS App: + asset_id: 0 attack_steps: - fullAccess name: Attacker:15 metadata: - MAL-Toolbox Version: 0.1.8 + MAL-Toolbox Version: 0.2.0 info: Created by the mal-toolbox model python module. langID: org.mal-lang.coreLang langVersion: 1.0.0 diff --git a/tests/testdata/models/traininglang_model.yml b/tests/testdata/models/traininglang_model.yml index 9c0b1c57..91b0c388 100644 --- a/tests/testdata/models/traininglang_model.yml +++ b/tests/testdata/models/traininglang_model.yml @@ -14,40 +14,38 @@ assets: 4: name: Network:3 type: Network - associations: - - HostsInNetworks: hosts: - - 0 - - 1 + 0: Host:0 + 1: Host:1 networks: - - 4 + 4: Network:3 - UsersOnHosts: - users: - - 3 hosts: - - 0 + 0: Host:0 + users: + 3: User:3 - DataOnHosts: data: - - 2 + 2: Data:2 hosts: - - 0 - + 0: Host:0 attackers: 5: entry_points: - 0: + Host:0: + asset_id: 0 attack_steps: - connect - 3: + User:3: + asset_id: 3 attack_steps: - phishing name: Attacker1 - metadata: - MAL-Toolbox Version: 0.1.8 - info: Created manually by Joakim. + MAL-Toolbox Version: 0.2.0 + info: Created by the mal-toolbox model python module. langID: org.mal-lang.trainingLang langVersion: 1.0.0 malVersion: 0.1.0-SNAPSHOT