diff --git a/packages/kestrel_core/src/kestrel/analytics/interface.py b/packages/kestrel_core/src/kestrel/analytics/interface.py index fe3cf578..d1f44b8a 100644 --- a/packages/kestrel_core/src/kestrel/analytics/interface.py +++ b/packages/kestrel_core/src/kestrel/analytics/interface.py @@ -111,7 +111,7 @@ def my_analytic(df: pd.DataFrame, x: int = 0, y: float = 0.5) from uuid import UUID from kestrel.analytics.config import get_profile, load_profiles -from kestrel.display import GraphletExplanation +from kestrel.display import AnalyticOperation, GraphletExplanation from kestrel.exceptions import ( AnalyticsError, InvalidAnalytics, @@ -156,6 +156,10 @@ def run(self, config: dict) -> DataFrame: _logger.debug("python analytics job result:\n%s", df) return df + def get_module_and_func_name(self, config: dict) -> str: + module_name, func_name = get_profile(self.analytic, config) + return module_name, func_name + class PythonAnalyticsInterface(AnalyticsInterface): def __init__( @@ -187,9 +191,20 @@ def store( def explain_graph( self, graph: IRGraphEvaluable, + cache: MutableMapping[UUID, Any], instructions_to_explain: Optional[Iterable[Instruction]] = None, ) -> Mapping[UUID, GraphletExplanation]: - raise NotImplementedError("PythonAnalyticsInterface.explain_graph") # TEMP + mapping = {} + if not instructions_to_explain: + instructions_to_explain = graph.get_sink_nodes() + for instruction in instructions_to_explain: + dep_graph = graph.duplicate_dependent_subgraph_of_node(instruction) + graph_dict = dep_graph.to_dict() + job = self._evaluate_instruction_in_graph(graph, cache, instruction) + module_name, func_name = job.get_module_and_func_name(self.config) + action = AnalyticOperation("Python", module_name + "::" + func_name) + mapping[instruction.id] = GraphletExplanation(graph_dict, action) + return mapping def evaluate_graph( self, diff --git a/packages/kestrel_core/src/kestrel/cache/inmemory.py b/packages/kestrel_core/src/kestrel/cache/inmemory.py index 058ae7d9..6a902e8c 100644 --- a/packages/kestrel_core/src/kestrel/cache/inmemory.py +++ b/packages/kestrel_core/src/kestrel/cache/inmemory.py @@ -80,9 +80,9 @@ def explain_graph( instructions_to_explain: Optional[Iterable[Instruction]] = None, ) -> Mapping[UUID, GraphletExplanation]: mapping = {} - if not instructions_to_evaluate: - instructions_to_evaluate = graph.get_sink_nodes() - for instruction in instructions_to_evaluate: + if not instructions_to_explain: + instructions_to_explain = graph.get_sink_nodes() + for instruction in instructions_to_explain: dep_graph = graph.duplicate_dependent_subgraph_of_node(instruction) graph_dict = dep_graph.to_dict() query = NativeQuery("DataFrame", "") diff --git a/packages/kestrel_core/src/kestrel/display.py b/packages/kestrel_core/src/kestrel/display.py index 0cc5175a..649e4ed4 100644 --- a/packages/kestrel_core/src/kestrel/display.py +++ b/packages/kestrel_core/src/kestrel/display.py @@ -13,12 +13,20 @@ class NativeQuery(DataClassJSONMixin): statement: str +@dataclass +class AnalyticOperation(DataClassJSONMixin): + # which interface + interface: str + # operation description + operation: str + + @dataclass class GraphletExplanation(DataClassJSONMixin): # serialized IRGraph graph: Mapping # data source query - query: NativeQuery + action: Union[NativeQuery, AnalyticOperation] @dataclass diff --git a/packages/kestrel_core/src/kestrel/ir/graph.py b/packages/kestrel_core/src/kestrel/ir/graph.py index 6e4e2cbb..145e977c 100644 --- a/packages/kestrel_core/src/kestrel/ir/graph.py +++ b/packages/kestrel_core/src/kestrel/ir/graph.py @@ -863,6 +863,12 @@ def _add_node(self, node: Instruction, deref: bool = True) -> Instruction: self.store = node.store return super()._add_node(node, deref) + def to_dict(self) -> Mapping[str, Iterable[Mapping]]: + d = super().to_dict() + d["interface"] = self.interface + d["store"] = self.store + return d + @typechecked class IRGraphSimpleQuery(IRGraphEvaluable): diff --git a/packages/kestrel_core/tests/test_cache_sqlite.py b/packages/kestrel_core/tests/test_cache_sqlite.py index 901cd580..91528c74 100644 --- a/packages/kestrel_core/tests/test_cache_sqlite.py +++ b/packages/kestrel_core/tests/test_cache_sqlite.py @@ -291,7 +291,7 @@ def test_explain_find_event_to_entity(process_creation_events): assert len(rets) == 1 explanation = mapping[rets[0].id] construct = graph.get_nodes_by_type(Construct)[0] - stmt = explanation.query.statement.replace('"', '') + stmt = explanation.action.statement.replace('"', '') assert stmt == f"""WITH es AS (SELECT DISTINCT * FROM {construct.id.hex}), diff --git a/packages/kestrel_core/tests/test_session.py b/packages/kestrel_core/tests/test_session.py index ede527d1..0b1b4fc1 100644 --- a/packages/kestrel_core/tests/test_session.py +++ b/packages/kestrel_core/tests/test_session.py @@ -10,7 +10,7 @@ from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER from kestrel.display import GraphExplanation from kestrel.frontend.parser import parse_kestrel_and_update_irgraph -from kestrel.ir.graph import IRGraph +from kestrel.ir.graph import IRGraph, IRGraphEvaluable from kestrel.ir.instructions import Construct, SerializableDataFrame @@ -200,10 +200,10 @@ def test_explain_in_cache(): assert isinstance(res, GraphExplanation) assert len(res.graphlets) == 1 ge = res.graphlets[0] - assert ge.graph == session.irgraph.to_dict() + assert ge.graph == IRGraphEvaluable(session.irgraph).to_dict() construct = session.irgraph.get_nodes_by_type(Construct)[0] - assert ge.query.language == "SQL" - stmt = ge.query.statement.replace('"', '') + assert ge.action.language == "SQL" + stmt = ge.action.statement.replace('"', '') assert stmt == f"WITH proclist AS \n(SELECT DISTINCT * \nFROM {construct.id.hex}v), \nbrowsers AS \n(SELECT DISTINCT * \nFROM proclist \nWHERE name != 'cmd.exe'), \nchrome AS \n(SELECT DISTINCT * \nFROM browsers \nWHERE pid = 205)\n SELECT DISTINCT * \nFROM chrome" with pytest.raises(StopIteration): next(ress) @@ -275,28 +275,28 @@ def schemes(): # DISP procs assert len(disp.graphlets[0].graph["nodes"]) == 5 - query = disp.graphlets[0].query.statement.replace('"', '') + query = disp.graphlets[0].action.statement.replace('"', '') procs = session.irgraph.get_variable("procs") c1 = next(session.irgraph.predecessors(procs)) assert query == f"WITH procs AS \n(SELECT DISTINCT * \nFROM {c1.id.hex}), \np2 AS \n(SELECT DISTINCT * \nFROM procs \nWHERE name IN ('firefox.exe', 'chrome.exe'))\n SELECT DISTINCT pid \nFROM p2" # DISP nt assert len(disp.graphlets[1].graph["nodes"]) == 2 - query = disp.graphlets[1].query.statement.replace('"', '') + query = disp.graphlets[1].action.statement.replace('"', '') nt = session.irgraph.get_variable("nt") c2 = next(session.irgraph.predecessors(nt)) assert query == f"WITH nt AS \n(SELECT DISTINCT * \nFROM {c2.id.hex})\n SELECT DISTINCT * \nFROM nt" # DISP domain assert len(disp.graphlets[2].graph["nodes"]) == 2 - query = disp.graphlets[2].query.statement.replace('"', '') + query = disp.graphlets[2].action.statement.replace('"', '') domain = session.irgraph.get_variable("domain") c3 = next(session.irgraph.predecessors(domain)) assert query == f"WITH domain AS \n(SELECT DISTINCT * \nFROM {c3.id.hex})\n SELECT DISTINCT * \nFROM domain" # EXPLAIN d2 assert len(disp.graphlets[3].graph["nodes"]) == 11 - query = disp.graphlets[3].query.statement.replace('"', '') + query = disp.graphlets[3].action.statement.replace('"', '') p2 = session.irgraph.get_variable("p2") p2pa = next(session.irgraph.successors(p2)) assert query == f"WITH ntx AS \n(SELECT DISTINCT * \nFROM {nt.id.hex}v \nWHERE abc IN (SELECT DISTINCT * \nFROM {p2pa.id.hex}v)), \nd2 AS \n(SELECT DISTINCT * \nFROM {domain.id.hex}v \nWHERE ip IN (SELECT DISTINCT destination \nFROM ntx))\n SELECT DISTINCT * \nFROM d2" @@ -368,7 +368,7 @@ def test_explain_find_event_to_entity(process_creation_events): session.irgraph = process_creation_events res = session.execute("procs = FIND process RESPONDED es WHERE device.os = 'Linux' EXPLAIN procs")[0] construct = session.irgraph.get_nodes_by_type(Construct)[0] - stmt = res.graphlets[0].query.statement.replace('"', '') + stmt = res.graphlets[0].action.statement.replace('"', '') # cache.sql will use "*" as columns for __setitem__ in virtual cache # so the result is different from test_cache_sqlite::test_explain_find_event_to_entity assert stmt == f"WITH es AS \n(SELECT DISTINCT * \nFROM {construct.id.hex}v), \nprocs AS \n(SELECT DISTINCT * \nFROM es \nWHERE device.os = \'Linux\')\n SELECT DISTINCT * \nFROM procs" diff --git a/packages/kestrel_interface_sqlalchemy/tests/test_interface.py b/packages/kestrel_interface_sqlalchemy/tests/test_interface.py index 8e113a6e..098c3df6 100644 --- a/packages/kestrel_interface_sqlalchemy/tests/test_interface.py +++ b/packages/kestrel_interface_sqlalchemy/tests/test_interface.py @@ -165,7 +165,7 @@ def test_find_event_to_entity(setup_sqlite_ecs_process_creation): evs, explain, procs = session.execute(huntflow) assert evs.shape[0] == 9 # all events - stmt = explain.graphlets[0].query.statement + stmt = explain.graphlets[0].action.statement test_dir = os.path.dirname(os.path.abspath(__file__)) result_file = os.path.join(test_dir, "result_interface_find_event_to_entity.txt") with open(result_file) as h: @@ -187,7 +187,7 @@ def test_find_entity_to_event(setup_sqlite_ecs_process_creation): """ explain, e2 = session.execute(huntflow) - stmt = explain.graphlets[0].query.statement + stmt = explain.graphlets[0].action.statement test_dir = os.path.dirname(os.path.abspath(__file__)) result_file = os.path.join(test_dir, "result_interface_find_entity_to_event.txt") with open(result_file) as h: @@ -223,7 +223,7 @@ def test_find_entity_to_entity(setup_sqlite_ecs_process_creation): """ explain, parents = session.execute(huntflow) - stmt = explain.graphlets[0].query.statement + stmt = explain.graphlets[0].action.statement test_dir = os.path.dirname(os.path.abspath(__file__)) result_file = os.path.join(test_dir, "result_interface_find_entity_to_entity.txt") with open(result_file) as h: diff --git a/packages/kestrel_jupyter/src/kestrel_jupyter_kernel/display.py b/packages/kestrel_jupyter/src/kestrel_jupyter_kernel/display.py index a979392f..321cf9b8 100644 --- a/packages/kestrel_jupyter/src/kestrel_jupyter_kernel/display.py +++ b/packages/kestrel_jupyter/src/kestrel_jupyter_kernel/display.py @@ -1,12 +1,13 @@ import base64 -import tempfile +from io import BytesIO +from math import ceil, sqrt from typing import Iterable, Mapping import matplotlib.pyplot as plt import networkx as nx import numpy import sqlparse -from kestrel.display import Display, GraphExplanation +from kestrel.display import AnalyticOperation, Display, GraphExplanation, NativeQuery from kestrel.ir.graph import IRGraph from kestrel.ir.instructions import Construct, DataSource, Instruction, Variable from pandas import DataFrame @@ -39,7 +40,10 @@ def to_html_blocks(d: Display) -> Iterable[str]: elif isinstance(d, GraphExplanation): for graphlet in d.graphlets: graph = IRGraph(graphlet.graph) - plt.figure(figsize=(10, 8)) + yield f"
INTERFACE: {graphlet.graph['interface']}; STORE: {graphlet.graph['store']}
" + + fig_side_length = min(10, ceil(sqrt(len(graph))) + 1) + plt.figure(figsize=(fig_side_length, fig_side_length)) nx.draw( graph, with_labels=True, @@ -48,24 +52,29 @@ def to_html_blocks(d: Display) -> Iterable[str]: node_size=260, node_color="#bfdff5", ) - with tempfile.NamedTemporaryFile(delete_on_close=False) as tf: - tf.close() - plt.savefig(tf.name, format="png") - with open(tf.name, "rb") as tfx: - data = tfx.read() - - img = data_uri = base64.b64encode(data).decode("utf-8") + fig_buffer = BytesIO() + plt.savefig(fig_buffer, format="png") + img = data_uri = base64.b64encode(fig_buffer.getvalue()).decode("utf-8") imgx = f'' yield imgx - query = graphlet.query.statement - if graphlet.query.language == "SQL": - lexer = SqlLexer() - query = sqlparse.format(query, reindent=True, keyword_case="upper") - elif graphlet.query.language == "KQL": - lexer = KustoLexer() - else: - lexer = guess_lexer(query) - query = highlight(query, lexer, HtmlFormatter()) - style = "" - yield style + query + if isinstance(graphlet.action, NativeQuery): + native_query = graphlet.action + language = native_query.language + query = native_query.statement + if language == "SQL": + lexer = SqlLexer() + query = sqlparse.format(query, reindent=True, keyword_case="upper") + elif language == "KQL": + lexer = KustoLexer() + else: + lexer = guess_lexer(query) + query = highlight(query, lexer, HtmlFormatter()) + style = "" + yield style + query + elif isinstance(graphlet.action, AnalyticOperation): + analytic_operation = graphlet.action + data = { + "Analytics": [analytic_operation.operation], + } + yield DataFrame(data).to_html(index=False)