diff --git a/mal_gui/main_window.py b/mal_gui/main_window.py index a82859f..a8c7202 100644 --- a/mal_gui/main_window.py +++ b/mal_gui/main_window.py @@ -28,9 +28,7 @@ from maltoolbox.language import LanguageGraph from maltoolbox.model import Model, ModelAsset from maltoolbox.exceptions import ModelException -from malsim.scenario import ( - create_scenario_dict, save_scenario_dict, load_scenario_dict, Scenario -) +from malsim.scenario import Scenario from .file_utils import image_path from .model_scene import ModelScene @@ -109,7 +107,7 @@ def load_scene( self, lang_file_path: str, model: Model, - scenario_dict: Optional[dict[str, Any]] = None + scenario: Optional[Scenario] = None ): """Load scene with given language and model""" print("LOADING SCENE!") @@ -118,7 +116,7 @@ def load_scene( lang_graph = LanguageGraph.load_from_file(lang_file_path) self.asset_factory = self.create_asset_factory(lang_graph) self.scene = self.create_scene( - lang_graph, self.asset_factory, model, scenario_dict + lang_graph, self.asset_factory, model, scenario ) self.create_menu_bar() @@ -171,12 +169,12 @@ def create_scene( lang_graph: LanguageGraph, asset_factory: AssetFactory, model: Model, - scenario_dict: Optional[dict[str, Any]] = None + scenario: Optional[Scenario] = None ): """Create and initialize scene from language""" model_scene = ModelScene( - asset_factory, lang_graph, model, self, scenario_dict + asset_factory, lang_graph, model, self, scenario ) return model_scene @@ -572,25 +570,9 @@ def load_model_or_scenario(self): def load_scenario(self, file_path: str): """Load model and agents from a scenario""" - scenario_dict = load_scenario_dict(file_path) - lang_path = scenario_dict['lang_file'] - lang_graph = ( - LanguageGraph.load_from_file(lang_path) - ) - if 'model_file' in scenario_dict: - model = Model.load_from_file( - scenario_dict['model_file'], lang_graph - ) - elif 'model' in scenario_dict: - model = Model._from_dict( - scenario_dict['model'], lang_graph - ) - else: - raise KeyError("Can not find model or model file in scenario") - + scenario = Scenario.load_from_file(file_path) # Reload in case language was changed - self.load_scene(lang_path, model, scenario_dict) - self.scene.scenario_dict = scenario_dict + self.load_scene(scenario._lang_file, scenario.model, scenario) self.scenario_file_name = file_path def load_model(self, file_path: str): @@ -692,25 +674,49 @@ def save_as_scenario(self): file_dialog.setAcceptMode(QFileDialog.AcceptSave) file_dialog.setDefaultSuffix("yaml") file_path, _ = file_dialog.getSaveFileName() - agents: dict[str, dict[str, Any]] = {} + agents = self.scene.scenario._agents_dict if self.scene.scenario else {} + # Add attacker agents from scene for attacker_item in self.scene.attacker_items: - agents[attacker_item.name] = { - 'type': 'attacker', - 'entry_points': attacker_item.entry_points - } - - settings = {} + # Only thing that can be changed by GUI for agents is entry points + if attacker_item.name in agents: + # If agent already exists in scenario, update entrypoints + agents[attacker_item.name]['entry_points'] = attacker_item.entry_points + else: + # Otherwise, add new agent to scenario agents dict + agents[attacker_item.name] = { + 'entry_points': attacker_item.entry_points, + 'type': 'attacker', + 'agent_class': 'RandomAgent' + } if not file_path: print("No valid path detected for saving") return + else: self.add_positions_to_model() + # Create a new scenario based on settings in gui and save it to file + # TODO: this is a hacky solution, instead malsim scenario should be easier to work with scenario = Scenario( lang_file=self.lang_file_path, model=self.scene.model, - agents=agents + agents=agents, + rewards=( + self.scene.scenario._rewards_dict if self.scene.scenario else None + ), + false_negative_rates=( + self.scene.scenario._fnr_dict if self.scene.scenario else None + ), + false_positive_rates=( + self.scene.scenario._fpr_dict if self.scene.scenario else None + ), + is_actionable=( + self.scene.scenario._is_actionable_dict if self.scene.scenario else None + ), + is_observable=( + self.scene.scenario._is_observable_dict if self.scene.scenario else None + ) ) scenario.save_to_file(file_path) diff --git a/mal_gui/model_scene.py b/mal_gui/model_scene.py index d4fe7ea..173ce9e 100644 --- a/mal_gui/model_scene.py +++ b/mal_gui/model_scene.py @@ -17,6 +17,7 @@ from PySide6.QtCore import QLineF, Qt, QPointF, QRectF from maltoolbox.model import Model +from malsim.scenario import AgentType from .connection_item import AssociationConnectionItem,EntrypointConnectionItem from .connection_dialog import AssociationConnectionDialog,EntrypointConnectionDialog @@ -41,8 +42,8 @@ from object_explorer.asset_factory import AssetFactory from .main_window import MainWindow from maltoolbox.language import LanguageGraph - from maltoolbox.model import ModelAsset from .connection_item import IConnectionItem + from malsim.scenario import Scenario class ModelScene(QGraphicsScene): def __init__( @@ -51,7 +52,7 @@ def __init__( lang_graph: LanguageGraph, model: Model, main_window: MainWindow, - scenario_dict: Optional[dict[str, Any]] = None + scenario: Optional[Scenario] = None ): super().__init__() @@ -64,7 +65,7 @@ def __init__( # # instance model self.lang_graph = lang_graph self.model = model - self.scenario_dict = scenario_dict or {} + self.scenario = scenario self._asset_id_to_item = {} self.attacker_items: list[AttackerItem] = [] @@ -426,29 +427,30 @@ def draw_model(self): ) # Draw attackers if they exists in scenario - agents = self.scenario_dict.get('agents', {}) - for name, agent_info in agents.items(): + if self.scenario: + agents = self.scenario._agents_dict + for name, agent_info in agents.items(): - if agent_info['type'] != 'attacker': - continue - - attacker_item = self.create_attacker( - QPointF(0, 0), name, agent_info['entry_points'] - ) + if agent_info['type'] != 'attacker': + continue - for entrypoint_full_name in agent_info['entry_points']: - attack_step = entrypoint_full_name.split(":")[-1] - asset_name = ( - entrypoint_full_name.removesuffix(":" + attack_step) - ) - asset = self.model.get_asset_by_name(asset_name) - assert asset, "Asset does not exist" - self.add_entrypoint_connection( - attack_step, - attacker_item, - self._asset_id_to_item[asset.id] + attacker_item = self.create_attacker( + QPointF(0, 0), name, agent_info['entry_points'] ) + for entrypoint_full_name in agent_info['entry_points']: + attack_step = entrypoint_full_name.split(":")[-1] + asset_name = ( + entrypoint_full_name.removesuffix(":" + attack_step) + ) + asset = self.model.get_asset_by_name(asset_name) + assert asset, "Asset does not exist" + self.add_entrypoint_connection( + attack_step, + attacker_item, + self._asset_id_to_item[asset.id] + ) + # based on connectionType use attacker or # add_association_connection