diff --git a/mal_gui/main_window.py b/mal_gui/main_window.py index 8ee0069..803a78b 100644 --- a/mal_gui/main_window.py +++ b/mal_gui/main_window.py @@ -28,7 +28,9 @@ from maltoolbox.language import LanguageGraph from maltoolbox.model import Model, ModelAsset from maltoolbox.exceptions import ModelException +from malsim.config.agent_settings import AttackerSettings, AgentType from malsim.scenario import Scenario +from malsim.policies import RandomAgent from .file_utils import image_path from .model_scene import ModelScene @@ -675,21 +677,22 @@ def save_as_scenario(self): file_dialog.setDefaultSuffix("yaml") file_path, _ = file_dialog.getSaveFileName() - agents = self.scene.scenario._agents_dict if self.scene.scenario else {} + agents = self.scene.scenario.agent_settings if self.scene.scenario else {} # Add attacker agents from scene for attacker_item in self.scene.attacker_items: + agent = agents.get(attacker_item.name) # Only thing that can be changed by GUI for agents is entry points - if attacker_item.name in agents: + if isinstance(agent, AttackerSettings): # If agent already exists in scenario, update entrypoints - agents[attacker_item.name]['entry_points'] = attacker_item.entry_points + agent.entry_points = set(attacker_item.entry_points) else: # Otherwise, add new agent to scenario agents dict - agents[attacker_item.name] = { - 'entry_points': attacker_item.entry_points, - 'type': 'attacker', - 'agent_class': 'RandomAgent' - } - + agents[attacker_item.name] = AttackerSettings( + name=attacker_item.name, + entry_points=set(attacker_item.entry_points), + type=AgentType.ATTACKER, + policy=RandomAgent, + ) if not file_path: print("No valid path detected for saving") return @@ -698,25 +701,43 @@ def save_as_scenario(self): self.add_positions_to_model() # Create a new scenario based on settings in gui and save it to file # TODO: this is a hacky solution, instead malsim scenario should be easier to work with + rewards = None + false_negative_rates = None + false_positive_rates = None + is_actionable = None + is_observable = None + + if self.scene.scenario: + if self.scene.scenario.rewards: + rewards = ( + self.scene.scenario.rewards.to_dict() + ) + if self.scene.scenario.false_negative_rates: + false_negative_rates = ( + self.scene.scenario.false_negative_rates.to_dict() + ) + if self.scene.scenario.false_positive_rates: + false_positive_rates = ( + self.scene.scenario.false_positive_rates.to_dict() + ) + if self.scene.scenario.is_actionable: + is_actionable = ( + self.scene.scenario.is_actionable.to_dict() + ) + if self.scene.scenario.is_observable: + is_observable = ( + self.scene.scenario.is_observable.to_dict() + ) + scenario = Scenario( lang_file=self.lang_file_path, model=self.scene.model, - agents=agents, - rewards=( - self.scene.scenario._rewards_dict if self.scene.scenario else None - ), - false_negative_rates=( - self.scene.scenario._fnr_dict if self.scene.scenario else None - ), - false_positive_rates=( - self.scene.scenario._fpr_dict if self.scene.scenario else None - ), - is_actionable=( - self.scene.scenario._is_actionable_dict if self.scene.scenario else None - ), - is_observable=( - self.scene.scenario._is_observable_dict if self.scene.scenario else None - ) + agent_settings=agents, + rewards=rewards, + false_negative_rates=false_negative_rates, + false_positive_rates=false_positive_rates, + actionable_steps=is_actionable, + observable_steps=is_observable, ) scenario.save_to_file(file_path) diff --git a/mal_gui/model_scene.py b/mal_gui/model_scene.py index 173ce9e..f562935 100644 --- a/mal_gui/model_scene.py +++ b/mal_gui/model_scene.py @@ -17,7 +17,7 @@ from PySide6.QtCore import QLineF, Qt, QPointF, QRectF from maltoolbox.model import Model -from malsim.scenario import AgentType +from malsim.config.agent_settings import AttackerSettings from .connection_item import AssociationConnectionItem,EntrypointConnectionItem from .connection_dialog import AssociationConnectionDialog,EntrypointConnectionDialog @@ -428,29 +428,30 @@ def draw_model(self): # Draw attackers if they exists in scenario if self.scenario: - agents = self.scenario._agents_dict + agents = self.scenario.agent_settings for name, agent_info in agents.items(): - if agent_info['type'] != 'attacker': - continue - - attacker_item = self.create_attacker( - QPointF(0, 0), name, agent_info['entry_points'] - ) - - for entrypoint_full_name in agent_info['entry_points']: - attack_step = entrypoint_full_name.split(":")[-1] - asset_name = ( - entrypoint_full_name.removesuffix(":" + attack_step) - ) - asset = self.model.get_asset_by_name(asset_name) - assert asset, "Asset does not exist" - self.add_entrypoint_connection( - attack_step, - attacker_item, - self._asset_id_to_item[asset.id] + if isinstance(agent_info, AttackerSettings): + attacker_item = self.create_attacker( + QPointF(0, 0), name, agent_info.entry_points ) + for entry_point in agent_info.entry_points: + entrypoint_full_name = ( + entry_point if isinstance(entry_point, str) else entry_point.full_name + ) + attack_step = entrypoint_full_name.split(":")[-1] + asset_name = ( + entrypoint_full_name.removesuffix(":" + attack_step) + ) + asset = self.model.get_asset_by_name(asset_name) + assert asset, "Asset does not exist" + self.add_entrypoint_connection( + attack_step, + attacker_item, + self._asset_id_to_item[asset.id] + ) + # based on connectionType use attacker or # add_association_connection diff --git a/pyproject.toml b/pyproject.toml index a961649..bfeaa52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ dependencies = [ "PySide6_Addons~=6.8.1", "PySide6_Essentials~=6.8.1", "shiboken6~=6.8.1", - "mal-toolbox==1.*", - "mal-simulator==1.*", + "mal-toolbox==2.*", + "mal-simulator==2.*", "qt-material==2.14", "appdirs==1.4.4" ]