Skip to content

Commit

Permalink
Merge pull request #567 from opencybersecurityalliance/k2-apply-explain
Browse files Browse the repository at this point in the history
add EXPLAIN to APPLY
  • Loading branch information
subbyte authored Jul 28, 2024
2 parents a097cba + a597b4f commit f43e927
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 40 deletions.
19 changes: 17 additions & 2 deletions packages/kestrel_core/src/kestrel/analytics/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def my_analytic(df: pd.DataFrame, x: int = 0, y: float = 0.5)
from typing import Tuple
from uuid import UUID

from kestrel.analytics.config import get_profile, load_profiles
from kestrel.display import GraphletExplanation
from kestrel.display import AnalyticOperation, GraphletExplanation
from kestrel.exceptions import (
AnalyticsError,
InvalidAnalytics,
Expand Down Expand Up @@ -156,6 +156,10 @@ def run(self, config: dict) -> DataFrame:
_logger.debug("python analytics job result:\n%s", df)
return df

def get_module_and_func_name(self, config: dict) -> Tuple[str, str]:
    """Look up the module and function that implement this job's analytic.

    Args:
        config: configuration passed through to ``get_profile`` together
            with ``self.analytic``.

    Returns:
        A ``(module_name, func_name)`` pair as resolved by ``get_profile``.
    """
    # Unpack explicitly so a profile that does not resolve to exactly two
    # items fails here instead of at the caller.
    module_name, func_name = get_profile(self.analytic, config)
    return module_name, func_name


class PythonAnalyticsInterface(AnalyticsInterface):
def __init__(
Expand Down Expand Up @@ -187,9 +191,20 @@ def store(
def explain_graph(
    self,
    graph: IRGraphEvaluable,
    cache: MutableMapping[UUID, Any],
    instructions_to_explain: Optional[Iterable[Instruction]] = None,
) -> Mapping[UUID, GraphletExplanation]:
    """Explain which Python analytic each requested instruction maps to.

    Args:
        graph: the graph whose instructions are explained.
        cache: passed through to ``_evaluate_instruction_in_graph``.
        instructions_to_explain: instructions to explain; when empty or
            ``None``, the sink nodes of ``graph`` are used.

    Returns:
        A mapping from instruction id to a ``GraphletExplanation`` holding
        the serialized dependent subgraph and an ``AnalyticOperation``
        of the form ``"<module>::<function>"`` for the "Python" interface.
    """
    mapping = {}
    if not instructions_to_explain:
        instructions_to_explain = graph.get_sink_nodes()
    for instruction in instructions_to_explain:
        dep_graph = graph.duplicate_dependent_subgraph_of_node(instruction)
        graph_dict = dep_graph.to_dict()
        # NOTE(review): the job object comes from the evaluation helper;
        # confirm it only builds the job and does not run the analytic.
        job = self._evaluate_instruction_in_graph(graph, cache, instruction)
        module_name, func_name = job.get_module_and_func_name(self.config)
        action = AnalyticOperation("Python", module_name + "::" + func_name)
        mapping[instruction.id] = GraphletExplanation(graph_dict, action)
    return mapping

def evaluate_graph(
self,
Expand Down
6 changes: 3 additions & 3 deletions packages/kestrel_core/src/kestrel/cache/inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def explain_graph(
instructions_to_explain: Optional[Iterable[Instruction]] = None,
) -> Mapping[UUID, GraphletExplanation]:
mapping = {}
if not instructions_to_evaluate:
instructions_to_evaluate = graph.get_sink_nodes()
for instruction in instructions_to_evaluate:
if not instructions_to_explain:
instructions_to_explain = graph.get_sink_nodes()
for instruction in instructions_to_explain:
dep_graph = graph.duplicate_dependent_subgraph_of_node(instruction)
graph_dict = dep_graph.to_dict()
query = NativeQuery("DataFrame", "")
Expand Down
10 changes: 9 additions & 1 deletion packages/kestrel_core/src/kestrel/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@ class NativeQuery(DataClassJSONMixin):
statement: str


@dataclass
class AnalyticOperation(DataClassJSONMixin):
    """Description of an analytic step, usable as a graphlet's action in place of a native query."""

    # which interface runs the analytic, e.g. "Python"
    interface: str
    # operation description, e.g. "<module>::<function>" for a Python analytic
    operation: str


@dataclass
class GraphletExplanation(DataClassJSONMixin):
    """Explanation of one graphlet: its serialized graph and the action that evaluates it."""

    # serialized IRGraph
    graph: Mapping
    # what evaluates the graphlet: a data source query or an analytic operation
    action: Union[NativeQuery, AnalyticOperation]


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions packages/kestrel_core/src/kestrel/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,12 @@ def _add_node(self, node: Instruction, deref: bool = True) -> Instruction:
self.store = node.store
return super()._add_node(node, deref)

def to_dict(self) -> Mapping[str, Iterable[Mapping]]:
    """Serialize the graph via the parent class and add ``interface`` and ``store`` entries.

    NOTE(review): the two added values are not iterables of mappings, so the
    declared return type looks narrower than what is returned; confirm.
    """
    serialized = super().to_dict()
    serialized.update(interface=self.interface, store=self.store)
    return serialized


@typechecked
class IRGraphSimpleQuery(IRGraphEvaluable):
Expand Down
2 changes: 1 addition & 1 deletion packages/kestrel_core/tests/test_cache_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def test_explain_find_event_to_entity(process_creation_events):
assert len(rets) == 1
explanation = mapping[rets[0].id]
construct = graph.get_nodes_by_type(Construct)[0]
stmt = explanation.query.statement.replace('"', '')
stmt = explanation.action.statement.replace('"', '')
assert stmt == f"""WITH es AS
(SELECT DISTINCT *
FROM {construct.id.hex}),
Expand Down
18 changes: 9 additions & 9 deletions packages/kestrel_core/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER
from kestrel.display import GraphExplanation
from kestrel.frontend.parser import parse_kestrel_and_update_irgraph
from kestrel.ir.graph import IRGraph
from kestrel.ir.graph import IRGraph, IRGraphEvaluable
from kestrel.ir.instructions import Construct, SerializableDataFrame


Expand Down Expand Up @@ -200,10 +200,10 @@ def test_explain_in_cache():
assert isinstance(res, GraphExplanation)
assert len(res.graphlets) == 1
ge = res.graphlets[0]
assert ge.graph == session.irgraph.to_dict()
assert ge.graph == IRGraphEvaluable(session.irgraph).to_dict()
construct = session.irgraph.get_nodes_by_type(Construct)[0]
assert ge.query.language == "SQL"
stmt = ge.query.statement.replace('"', '')
assert ge.action.language == "SQL"
stmt = ge.action.statement.replace('"', '')
assert stmt == f"WITH proclist AS \n(SELECT DISTINCT * \nFROM {construct.id.hex}v), \nbrowsers AS \n(SELECT DISTINCT * \nFROM proclist \nWHERE name != 'cmd.exe'), \nchrome AS \n(SELECT DISTINCT * \nFROM browsers \nWHERE pid = 205)\n SELECT DISTINCT * \nFROM chrome"
with pytest.raises(StopIteration):
next(ress)
Expand Down Expand Up @@ -275,28 +275,28 @@ def schemes():

# DISP procs
assert len(disp.graphlets[0].graph["nodes"]) == 5
query = disp.graphlets[0].query.statement.replace('"', '')
query = disp.graphlets[0].action.statement.replace('"', '')
procs = session.irgraph.get_variable("procs")
c1 = next(session.irgraph.predecessors(procs))
assert query == f"WITH procs AS \n(SELECT DISTINCT * \nFROM {c1.id.hex}), \np2 AS \n(SELECT DISTINCT * \nFROM procs \nWHERE name IN ('firefox.exe', 'chrome.exe'))\n SELECT DISTINCT pid \nFROM p2"

# DISP nt
assert len(disp.graphlets[1].graph["nodes"]) == 2
query = disp.graphlets[1].query.statement.replace('"', '')
query = disp.graphlets[1].action.statement.replace('"', '')
nt = session.irgraph.get_variable("nt")
c2 = next(session.irgraph.predecessors(nt))
assert query == f"WITH nt AS \n(SELECT DISTINCT * \nFROM {c2.id.hex})\n SELECT DISTINCT * \nFROM nt"

# DISP domain
assert len(disp.graphlets[2].graph["nodes"]) == 2
query = disp.graphlets[2].query.statement.replace('"', '')
query = disp.graphlets[2].action.statement.replace('"', '')
domain = session.irgraph.get_variable("domain")
c3 = next(session.irgraph.predecessors(domain))
assert query == f"WITH domain AS \n(SELECT DISTINCT * \nFROM {c3.id.hex})\n SELECT DISTINCT * \nFROM domain"

# EXPLAIN d2
assert len(disp.graphlets[3].graph["nodes"]) == 11
query = disp.graphlets[3].query.statement.replace('"', '')
query = disp.graphlets[3].action.statement.replace('"', '')
p2 = session.irgraph.get_variable("p2")
p2pa = next(session.irgraph.successors(p2))
assert query == f"WITH ntx AS \n(SELECT DISTINCT * \nFROM {nt.id.hex}v \nWHERE abc IN (SELECT DISTINCT * \nFROM {p2pa.id.hex}v)), \nd2 AS \n(SELECT DISTINCT * \nFROM {domain.id.hex}v \nWHERE ip IN (SELECT DISTINCT destination \nFROM ntx))\n SELECT DISTINCT * \nFROM d2"
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_explain_find_event_to_entity(process_creation_events):
session.irgraph = process_creation_events
res = session.execute("procs = FIND process RESPONDED es WHERE device.os = 'Linux' EXPLAIN procs")[0]
construct = session.irgraph.get_nodes_by_type(Construct)[0]
stmt = res.graphlets[0].query.statement.replace('"', '')
stmt = res.graphlets[0].action.statement.replace('"', '')
# cache.sql will use "*" as columns for __setitem__ in virtual cache
# so the result is different from test_cache_sqlite::test_explain_find_event_to_entity
assert stmt == f"WITH es AS \n(SELECT DISTINCT * \nFROM {construct.id.hex}v), \nprocs AS \n(SELECT DISTINCT * \nFROM es \nWHERE device.os = \'Linux\')\n SELECT DISTINCT * \nFROM procs"
6 changes: 3 additions & 3 deletions packages/kestrel_interface_sqlalchemy/tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_find_event_to_entity(setup_sqlite_ecs_process_creation):
evs, explain, procs = session.execute(huntflow)
assert evs.shape[0] == 9 # all events

stmt = explain.graphlets[0].query.statement
stmt = explain.graphlets[0].action.statement
test_dir = os.path.dirname(os.path.abspath(__file__))
result_file = os.path.join(test_dir, "result_interface_find_event_to_entity.txt")
with open(result_file) as h:
Expand All @@ -187,7 +187,7 @@ def test_find_entity_to_event(setup_sqlite_ecs_process_creation):
"""
explain, e2 = session.execute(huntflow)

stmt = explain.graphlets[0].query.statement
stmt = explain.graphlets[0].action.statement
test_dir = os.path.dirname(os.path.abspath(__file__))
result_file = os.path.join(test_dir, "result_interface_find_entity_to_event.txt")
with open(result_file) as h:
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_find_entity_to_entity(setup_sqlite_ecs_process_creation):
"""
explain, parents = session.execute(huntflow)

stmt = explain.graphlets[0].query.statement
stmt = explain.graphlets[0].action.statement
test_dir = os.path.dirname(os.path.abspath(__file__))
result_file = os.path.join(test_dir, "result_interface_find_entity_to_entity.txt")
with open(result_file) as h:
Expand Down
51 changes: 30 additions & 21 deletions packages/kestrel_jupyter/src/kestrel_jupyter_kernel/display.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import base64
import tempfile
from io import BytesIO
from math import ceil, sqrt
from typing import Iterable, Mapping

import matplotlib.pyplot as plt
import networkx as nx
import numpy
import sqlparse
from kestrel.display import Display, GraphExplanation
from kestrel.display import AnalyticOperation, Display, GraphExplanation, NativeQuery
from kestrel.ir.graph import IRGraph
from kestrel.ir.instructions import Construct, DataSource, Instruction, Variable
from pandas import DataFrame
Expand Down Expand Up @@ -39,7 +40,10 @@ def to_html_blocks(d: Display) -> Iterable[str]:
elif isinstance(d, GraphExplanation):
for graphlet in d.graphlets:
graph = IRGraph(graphlet.graph)
plt.figure(figsize=(10, 8))
yield f"<h5>INTERFACE: {graphlet.graph['interface']}; STORE: {graphlet.graph['store']}</h5>"

fig_side_length = min(10, ceil(sqrt(len(graph))) + 1)
plt.figure(figsize=(fig_side_length, fig_side_length))
nx.draw(
graph,
with_labels=True,
Expand All @@ -48,24 +52,29 @@ def to_html_blocks(d: Display) -> Iterable[str]:
node_size=260,
node_color="#bfdff5",
)
with tempfile.NamedTemporaryFile(delete_on_close=False) as tf:
tf.close()
plt.savefig(tf.name, format="png")
with open(tf.name, "rb") as tfx:
data = tfx.read()

img = data_uri = base64.b64encode(data).decode("utf-8")
fig_buffer = BytesIO()
plt.savefig(fig_buffer, format="png")
img = data_uri = base64.b64encode(fig_buffer.getvalue()).decode("utf-8")
imgx = f'<img src="data:image/png;base64,{img}">'
yield imgx

query = graphlet.query.statement
if graphlet.query.language == "SQL":
lexer = SqlLexer()
query = sqlparse.format(query, reindent=True, keyword_case="upper")
elif graphlet.query.language == "KQL":
lexer = KustoLexer()
else:
lexer = guess_lexer(query)
query = highlight(query, lexer, HtmlFormatter())
style = "<style>" + HtmlFormatter().get_style_defs() + "</style>"
yield style + query
if isinstance(graphlet.action, NativeQuery):
native_query = graphlet.action
language = native_query.language
query = native_query.statement
if language == "SQL":
lexer = SqlLexer()
query = sqlparse.format(query, reindent=True, keyword_case="upper")
elif language == "KQL":
lexer = KustoLexer()
else:
lexer = guess_lexer(query)
query = highlight(query, lexer, HtmlFormatter())
style = "<style>" + HtmlFormatter().get_style_defs() + "</style>"
yield style + query
elif isinstance(graphlet.action, AnalyticOperation):
analytic_operation = graphlet.action
data = {
"Analytics": [analytic_operation.operation],
}
yield DataFrame(data).to_html(index=False)

0 comments on commit f43e927

Please sign in to comment.