diff --git a/malsim/__init__.py b/malsim/__init__.py index a45de526..dc74d0b5 100644 --- a/malsim/__init__.py +++ b/malsim/__init__.py @@ -23,8 +23,8 @@ run_simulation, ) -from malsim.scenario.scenario import Scenario -from malsim.config.agent_settings import AttackerSettings, DefenderSettings +from malsim.scenario import Scenario +from malsim.config import AttackerSettings, DefenderSettings __title__ = 'malsim' __version__ = '2.0.0' diff --git a/malsim/config/__init__.py b/malsim/config/__init__.py new file mode 100644 index 00000000..eadd09c2 --- /dev/null +++ b/malsim/config/__init__.py @@ -0,0 +1,12 @@ +from .agent_settings import AttackerSettings, DefenderSettings +from .sim_settings import MalSimulatorSettings, TTCMode, RewardMode +from .node_property_rule import NodePropertyRule + +__all__ = [ + 'AttackerSettings', + 'DefenderSettings', + 'MalSimulatorSettings', + 'NodePropertyRule', + 'RewardMode', + 'TTCMode', +] diff --git a/malsim/config/agent_settings.py b/malsim/config/agent_settings.py index f18dae5f..279eeae3 100644 --- a/malsim/config/agent_settings.py +++ b/malsim/config/agent_settings.py @@ -7,30 +7,6 @@ from maltoolbox.attackgraph import AttackGraphNode from malsim.config.node_property_rule import NodePropertyRule -from malsim.policies import ( - BreadthFirstAttacker, - DepthFirstAttacker, - KeyboardAgent, - PassiveAgent, - DefendCompromisedDefender, - DefendFutureCompromisedDefender, - RandomAgent, - TTCSoftMinAttacker, - ShortestPathAttacker, -) - -policy_name_to_class = { - 'KeyboardAgent': KeyboardAgent, - 'PassiveAgent': PassiveAgent, - 'DepthFirstAttacker': DepthFirstAttacker, - 'BreadthFirstAttacker': BreadthFirstAttacker, - 'TTCSoftMinAttacker': TTCSoftMinAttacker, - 'ShortestPathAttacker': ShortestPathAttacker, - 'DefendCompromisedDefender': DefendCompromisedDefender, - 'DefendFutureCompromisedDefender': DefendFutureCompromisedDefender, - 'RandomAgent': RandomAgent, -} - class AgentType(Enum): """Enum for agent types""" @@ -138,58 +114,6 @@ def to_dict(self) -> dict[str, Any]: return d -def agent_settings_from_dict( - name: str, - d: dict[str, Any], -) -> AttackerSettings | DefenderSettings: - """Load agent settings from a dict""" - - agent_type = AgentType(d['type']) - - # Resolve policy class if provided - policy = None - policy_name = d.get('policy') or d.get('agent_class') - if policy_name: - if policy_name not in policy_name_to_class: - raise LookupError( - f"Policy class '{policy_name}' not supported. " - f'Must be one of: {list(policy_name_to_class.keys())}' - ) - policy = policy_name_to_class[policy_name] - - config = d.get('config', {}) - - if agent_type == AgentType.ATTACKER: - return AttackerSettings( - name=name, - entry_points=set(d['entry_points']), - goals=set(d.get('goals', [])), - ttc_overrides=NodePropertyRule.from_optional_dict(d.get('ttc_overrides')), - policy=policy, - actionable_steps=NodePropertyRule.from_optional_dict( - d.get('actionable_steps') - ), - rewards=NodePropertyRule.from_optional_dict(d.get('rewards')), - config=config, - ) - - # Defender - return DefenderSettings( - name=name, - policy=policy, - observable_steps=NodePropertyRule.from_optional_dict(d.get('observable_steps')), - actionable_steps=NodePropertyRule.from_optional_dict(d.get('actionable_steps')), - rewards=NodePropertyRule.from_optional_dict(d.get('rewards')), - false_positive_rates=NodePropertyRule.from_optional_dict( - d.get('false_positive_rates') - ), - false_negative_rates=NodePropertyRule.from_optional_dict( - d.get('false_negative_rates') - ), - config=config, - ) - - def get_defender_settings( agent_settings: dict[str, DefenderSettings | AttackerSettings], ) -> dict[str, DefenderSettings]: diff --git a/malsim/config/agent_settings_factories.py b/malsim/config/agent_settings_factories.py new file mode 100644 index 00000000..b5a320fa --- /dev/null +++ b/malsim/config/agent_settings_factories.py @@ -0,0 +1,80 @@ +from typing import Any + +from malsim.config.agent_settings import AgentType, AttackerSettings, DefenderSettings + +from malsim.config.node_property_rule import NodePropertyRule +from malsim.policies import ( + BreadthFirstAttacker, + DepthFirstAttacker, + KeyboardAgent, + PassiveAgent, + DefendCompromisedDefender, + DefendFutureCompromisedDefender, + RandomAgent, + TTCSoftMinAttacker, + ShortestPathAttacker, +) + +policy_name_to_class = { + 'KeyboardAgent': KeyboardAgent, + 'PassiveAgent': PassiveAgent, + 'DepthFirstAttacker': DepthFirstAttacker, + 'BreadthFirstAttacker': BreadthFirstAttacker, + 'TTCSoftMinAttacker': TTCSoftMinAttacker, + 'ShortestPathAttacker': ShortestPathAttacker, + 'DefendCompromisedDefender': DefendCompromisedDefender, + 'DefendFutureCompromisedDefender': DefendFutureCompromisedDefender, + 'RandomAgent': RandomAgent, +} + + +def agent_settings_from_dict( + name: str, + d: dict[str, Any], +) -> AttackerSettings | DefenderSettings: + """Load agent settings from a dict""" + + agent_type = AgentType(d['type']) + + # Resolve policy class if provided + policy = None + policy_name = d.get('policy') or d.get('agent_class') + if policy_name: + if policy_name not in policy_name_to_class: + raise LookupError( + f"Policy class '{policy_name}' not supported. " + f'Must be one of: {list(policy_name_to_class.keys())}' + ) + policy = policy_name_to_class[policy_name] + + config = d.get('config', {}) + + if agent_type == AgentType.ATTACKER: + return AttackerSettings( + name=name, + entry_points=set(d['entry_points']), + goals=set(d.get('goals', [])), + ttc_overrides=NodePropertyRule.from_optional_dict(d.get('ttc_overrides')), + policy=policy, + actionable_steps=NodePropertyRule.from_optional_dict( + d.get('actionable_steps') + ), + rewards=NodePropertyRule.from_optional_dict(d.get('rewards')), + config=config, + ) + + # Defender + return DefenderSettings( + name=name, + policy=policy, + observable_steps=NodePropertyRule.from_optional_dict(d.get('observable_steps')), + actionable_steps=NodePropertyRule.from_optional_dict(d.get('actionable_steps')), + rewards=NodePropertyRule.from_optional_dict(d.get('rewards')), + false_positive_rates=NodePropertyRule.from_optional_dict( + d.get('false_positive_rates') + ), + false_negative_rates=NodePropertyRule.from_optional_dict( + d.get('false_negative_rates') + ), + config=config, + ) diff --git a/malsim/mal_simulator/attacker_state.py b/malsim/mal_simulator/attacker_state.py index 06edacf7..42b78708 100644 --- a/malsim/mal_simulator/attacker_state.py +++ b/malsim/mal_simulator/attacker_state.py @@ -5,7 +5,7 @@ from malsim.config.node_property_rule import NodePropertyRule from malsim.mal_simulator.agent_state import MalSimAgentState from malsim.mal_simulator.ttc_utils import TTCDist -from malsim.mal_simulator.types import AgentStates +from malsim.types import AgentStates @dataclass(frozen=True) diff --git a/malsim/mal_simulator/defender_state.py b/malsim/mal_simulator/defender_state.py index 494b8ad4..cc4da614 100644 --- a/malsim/mal_simulator/defender_state.py +++ b/malsim/mal_simulator/defender_state.py @@ -4,7 +4,7 @@ from maltoolbox.attackgraph import AttackGraphNode from malsim.config.node_property_rule import NodePropertyRule from malsim.mal_simulator.agent_state import MalSimAgentState -from malsim.mal_simulator.types import AgentStates +from malsim.types import AgentStates @dataclass(frozen=True) diff --git a/malsim/mal_simulator/defender_step.py b/malsim/mal_simulator/defender_step.py index 0867f041..b48ddc2f 100644 --- a/malsim/mal_simulator/defender_step.py +++ b/malsim/mal_simulator/defender_step.py @@ -10,7 +10,7 @@ from malsim.mal_simulator.simulator_state import MalSimulatorState if TYPE_CHECKING: - from malsim.mal_simulator.types import AgentStates + from malsim.types import AgentStates from malsim.mal_simulator.defender_state import MalSimDefenderState logger = logging.getLogger(__name__) diff --git a/malsim/mal_simulator/register_agent.py b/malsim/mal_simulator/register_agent.py index d7972c2e..a6fee867 100644 --- a/malsim/mal_simulator/register_agent.py +++ b/malsim/mal_simulator/register_agent.py @@ -12,7 +12,7 @@ from malsim.mal_simulator.rewards import attacker_step_reward, defender_step_reward from malsim.mal_simulator.simulator_state import MalSimulatorState from malsim.config.agent_settings import AttackerSettings, DefenderSettings -from malsim.mal_simulator.types import AgentRewards, AgentStates, AgentSettings +from malsim.types import AgentRewards, AgentStates, AgentSettings def register_attacker_settings( diff --git a/malsim/mal_simulator/reset_agent.py b/malsim/mal_simulator/reset_agent.py index 78606c45..72fdcf88 100644 --- a/malsim/mal_simulator/reset_agent.py +++ b/malsim/mal_simulator/reset_agent.py @@ -10,7 +10,7 @@ from malsim.mal_simulator.rewards import attacker_step_reward, defender_step_reward from malsim.mal_simulator.simulator_state import MalSimulatorState from malsim.config.agent_settings import get_defender_settings, get_attacker_settings -from malsim.mal_simulator.types import AgentRewards, AgentStates, AgentSettings +from malsim.types import AgentRewards, AgentStates, AgentSettings def reset_attackers( diff --git a/malsim/mal_simulator/run_simulation.py b/malsim/mal_simulator/run_simulation.py index 48749ba4..eb5e770c 100644 --- a/malsim/mal_simulator/run_simulation.py +++ b/malsim/mal_simulator/run_simulation.py @@ -4,7 +4,7 @@ from malsim.policies.decision_agent import DecisionAgent from malsim.mal_simulator.simulator import MalSimulator -from malsim.mal_simulator.types import AgentSettings +from malsim.types import AgentSettings def run_simulation( diff --git a/malsim/mal_simulator/simulator.py b/malsim/mal_simulator/simulator.py index e28e1a74..94b30f2a 100644 --- a/malsim/mal_simulator/simulator.py +++ b/malsim/mal_simulator/simulator.py @@ -38,7 +38,7 @@ node_false_positive_rate, ) from malsim.config.agent_settings import AttackerSettings, DefenderSettings -from malsim.mal_simulator.types import ( +from malsim.types import ( AgentRewards, AgentStates, AgentSettings, diff --git a/malsim/mal_simulator/state_query.py b/malsim/mal_simulator/state_query.py index ca78e38a..a2d56283 100644 --- a/malsim/mal_simulator/state_query.py +++ b/malsim/mal_simulator/state_query.py @@ -9,7 +9,7 @@ from malsim.mal_simulator.defender_state import get_defender_agents from malsim.mal_simulator.node_getters import full_name_or_node_to_node from malsim.config.sim_settings import TTCMode -from malsim.mal_simulator.types import AgentStates +from malsim.types import AgentStates def node_is_enabled_defense( diff --git a/malsim/scenario/scenario.py b/malsim/scenario/scenario.py index 0c8fde75..aa4b35a1 100644 --- a/malsim/scenario/scenario.py +++ b/malsim/scenario/scenario.py @@ -23,11 +23,11 @@ from maltoolbox.language import LanguageGraph from maltoolbox.attackgraph import create_attack_graph -from malsim.config.agent_settings import ( +from malsim.config import ( DefenderSettings, AttackerSettings, - agent_settings_from_dict, ) +from malsim.config.agent_settings_factories import agent_settings_from_dict from malsim.config.node_property_rule import NodePropertyRule diff --git a/malsim/mal_simulator/types.py b/malsim/types.py similarity index 100% rename from malsim/mal_simulator/types.py rename to malsim/types.py