Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 39 additions & 33 deletions mal_gui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
from maltoolbox.language import LanguageGraph
from maltoolbox.model import Model, ModelAsset
from maltoolbox.exceptions import ModelException
from malsim.scenario import (
create_scenario_dict, save_scenario_dict, load_scenario_dict, Scenario
)
from malsim.scenario import Scenario

from .file_utils import image_path
from .model_scene import ModelScene
Expand Down Expand Up @@ -109,7 +107,7 @@ def load_scene(
self,
lang_file_path: str,
model: Model,
scenario_dict: Optional[dict[str, Any]] = None
scenario: Optional[Scenario] = None
):
"""Load scene with given language and model"""
print("LOADING SCENE!")
Expand All @@ -118,7 +116,7 @@ def load_scene(
lang_graph = LanguageGraph.load_from_file(lang_file_path)
self.asset_factory = self.create_asset_factory(lang_graph)
self.scene = self.create_scene(
lang_graph, self.asset_factory, model, scenario_dict
lang_graph, self.asset_factory, model, scenario
)

self.create_menu_bar()
Expand Down Expand Up @@ -171,12 +169,12 @@ def create_scene(
lang_graph: LanguageGraph,
asset_factory: AssetFactory,
model: Model,
scenario_dict: Optional[dict[str, Any]] = None
scenario: Optional[Scenario] = None
):
"""Create and initialize scene from language"""

model_scene = ModelScene(
asset_factory, lang_graph, model, self, scenario_dict
asset_factory, lang_graph, model, self, scenario
)

return model_scene
Expand Down Expand Up @@ -572,25 +570,9 @@ def load_model_or_scenario(self):

def load_scenario(self, file_path: str):
"""Load model and agents from a scenario"""
scenario_dict = load_scenario_dict(file_path)
lang_path = scenario_dict['lang_file']
lang_graph = (
LanguageGraph.load_from_file(lang_path)
)
if 'model_file' in scenario_dict:
model = Model.load_from_file(
scenario_dict['model_file'], lang_graph
)
elif 'model' in scenario_dict:
model = Model._from_dict(
scenario_dict['model'], lang_graph
)
else:
raise KeyError("Can not find model or model file in scenario")

scenario = Scenario.load_from_file(file_path)
# Reload in case language was changed
self.load_scene(lang_path, model, scenario_dict)
self.scene.scenario_dict = scenario_dict
self.load_scene(scenario._lang_file, scenario.model, scenario)
self.scenario_file_name = file_path

def load_model(self, file_path: str):
Expand Down Expand Up @@ -692,25 +674,49 @@ def save_as_scenario(self):
file_dialog.setAcceptMode(QFileDialog.AcceptSave)
file_dialog.setDefaultSuffix("yaml")
file_path, _ = file_dialog.getSaveFileName()
agents: dict[str, dict[str, Any]] = {}

agents = self.scene.scenario._agents_dict if self.scene.scenario else {}
# Add attacker agents from scene
for attacker_item in self.scene.attacker_items:
agents[attacker_item.name] = {
'type': 'attacker',
'entry_points': attacker_item.entry_points
}

settings = {}
# Only thing that can be changed by GUI for agents is entry points
if attacker_item.name in agents:
# If agent already exists in scenario, update entrypoints
agents[attacker_item.name]['entry_points'] = attacker_item.entry_points
else:
# Otherwise, add new agent to scenario agents dict
agents[attacker_item.name] = {
'entry_points': attacker_item.entry_points,
'type': 'attacker',
'agent_class': 'RandomAgent'
}

if not file_path:
print("No valid path detected for saving")
return

else:
self.add_positions_to_model()
# Create a new scenario based on settings in gui and save it to file
# TODO: this is a hacky solution, instead malsim scenario should be easier to work with
scenario = Scenario(
lang_file=self.lang_file_path,
model=self.scene.model,
agents=agents
agents=agents,
rewards=(
self.scene.scenario._rewards_dict if self.scene.scenario else None
),
false_negative_rates=(
self.scene.scenario._fnr_dict if self.scene.scenario else None
),
false_positive_rates=(
self.scene.scenario._fpr_dict if self.scene.scenario else None
),
is_actionable=(
self.scene.scenario._is_actionable_dict if self.scene.scenario else None
),
is_observable=(
self.scene.scenario._is_observable_dict if self.scene.scenario else None
)
)
scenario.save_to_file(file_path)

Expand Down
46 changes: 24 additions & 22 deletions mal_gui/model_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from PySide6.QtCore import QLineF, Qt, QPointF, QRectF

from maltoolbox.model import Model
from malsim.scenario import AgentType

from .connection_item import AssociationConnectionItem,EntrypointConnectionItem
from .connection_dialog import AssociationConnectionDialog,EntrypointConnectionDialog
Expand All @@ -41,8 +42,8 @@
from object_explorer.asset_factory import AssetFactory
from .main_window import MainWindow
from maltoolbox.language import LanguageGraph
from maltoolbox.model import ModelAsset
from .connection_item import IConnectionItem
from malsim.scenario import Scenario

class ModelScene(QGraphicsScene):
def __init__(
Expand All @@ -51,7 +52,7 @@ def __init__(
lang_graph: LanguageGraph,
model: Model,
main_window: MainWindow,
scenario_dict: Optional[dict[str, Any]] = None
scenario: Optional[Scenario] = None
):
super().__init__()

Expand All @@ -64,7 +65,7 @@ def __init__(
# # instance model
self.lang_graph = lang_graph
self.model = model
self.scenario_dict = scenario_dict or {}
self.scenario = scenario

self._asset_id_to_item = {}
self.attacker_items: list[AttackerItem] = []
Expand Down Expand Up @@ -426,29 +427,30 @@ def draw_model(self):
)

# Draw attackers if they exists in scenario
agents = self.scenario_dict.get('agents', {})
for name, agent_info in agents.items():
if self.scenario:
agents = self.scenario._agents_dict
for name, agent_info in agents.items():

if agent_info['type'] != 'attacker':
continue

attacker_item = self.create_attacker(
QPointF(0, 0), name, agent_info['entry_points']
)
if agent_info['type'] != 'attacker':
continue

for entrypoint_full_name in agent_info['entry_points']:
attack_step = entrypoint_full_name.split(":")[-1]
asset_name = (
entrypoint_full_name.removesuffix(":" + attack_step)
)
asset = self.model.get_asset_by_name(asset_name)
assert asset, "Asset does not exist"
self.add_entrypoint_connection(
attack_step,
attacker_item,
self._asset_id_to_item[asset.id]
attacker_item = self.create_attacker(
QPointF(0, 0), name, agent_info['entry_points']
)

for entrypoint_full_name in agent_info['entry_points']:
attack_step = entrypoint_full_name.split(":")[-1]
asset_name = (
entrypoint_full_name.removesuffix(":" + attack_step)
)
asset = self.model.get_asset_by_name(asset_name)
assert asset, "Asset does not exist"
self.add_entrypoint_connection(
attack_step,
attacker_item,
self._asset_id_to_item[asset.id]
)


# based on connectionType use attacker or
# add_association_connection
Expand Down