Finished adding first monkeypatch and singleton support to install that monkeypatch. Got first real input+output example working with purely transpiled code. Other minor bumps.
scnerd committed Jan 24, 2025
1 parent 708b9ff commit 731c90d
Showing 29 changed files with 2,054 additions and 80 deletions.
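
For context, the updated stats tests later in this diff show the shape of the code the transpiler now emits for a simple `stats count by _time` query. A minimal sketch of running that output by hand (it assumes an active SparkSession bound to `spark` and `pyspark.sql.functions` imported as `F`; both are assumptions about the surrounding harness, not part of this commit):

from pyspark.sql import SparkSession, functions as F
from spl_transpiler.runtime.monkeypatches import install_monkeypatches

spark = SparkSession.builder.getOrCreate()

# Install the DataFrame monkeypatch once; install_monkeypatches() is cached,
# so repeated calls are harmless.
install_monkeypatches()

# Group by _time (exploding it first if it happens to be an array column),
# then count rows per group.
result = (
    spark.table("main")
    ._spltranspiler__groupByMaybeExploded(["_time"])
    .agg(F.count(F.lit(1)).alias("count"))
)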
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -10,7 +10,7 @@ repos:
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.1
rev: v0.9.3
hooks:
# Run the linter.
- id: ruff
@@ -21,6 +21,6 @@ repos:
rev: v1.0
hooks:
- id: cargo-check
# - id: clippy
# args: [ --fix, --allow-dirty, --allow-staged, --no-deps ]
- id: clippy
args: [ --fix, --allow-dirty, --allow-staged, --no-deps ]
- id: fmt
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -48,11 +48,13 @@ addopts = "--ignore=tests/test_sample_files_parse.py --ignore=tests/sample_data"
dev-dependencies = [
"maturin",
"pip",

"pytest",
"pytest-dependency",
"tabulate",
"pyparsing",
"pyyaml",
"ruff-api",
"jupyterlab",
"ipywidgets",
"pandas[pyarrow]>=2.2.3",
]
10 changes: 10 additions & 0 deletions python/spl_transpiler/runtime/commands/rename.py
@@ -0,0 +1,10 @@
from pyspark.sql import DataFrame

from spl_transpiler.runtime.base import enforce_types, Expr


@enforce_types
def rename(df: DataFrame, **renames: Expr) -> DataFrame:
for new_name, old_expr in renames.items():
df = df.withColumnRenamed(old_expr.to_pyspark_expr(), new_name)
return df
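
For reference, each keyword argument maps an existing column (the `Expr` value) to a new name (the keyword); a minimal hand-written equivalent in plain PySpark, assuming the `Expr` resolves to an existing column name:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "alice")], ["uid", "user_name"])

# rename(df, user=<Expr for "user_name">) behaves like:
df = df.withColumnRenamed("user_name", "user")  # the keyword becomes the new name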
10 changes: 10 additions & 0 deletions python/spl_transpiler/runtime/commands/spath.py
@@ -0,0 +1,10 @@
from pyspark.sql import DataFrame

from spl_transpiler.runtime.base import enforce_types, Expr


@enforce_types
def spath(
df: DataFrame, path: str, *, input_: Expr = "raw", output: Expr | None = None
) -> DataFrame:
pass
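
The body is still a stub at this point. A hypothetical sketch of what a JSON-only implementation might look like (not part of this commit; it assumes the input column holds JSON text and that dotted SPL paths map onto JSONPath):

from pyspark.sql import DataFrame, functions as F


def spath_sketch(
    df: DataFrame, path: str, *, input_: str = "raw", output: str | None = None
) -> DataFrame:
    # Extract the value at `path` from the JSON text in `input_` and store it
    # under `output`, defaulting to the path itself.
    return df.withColumn(output or path, F.get_json_object(F.col(input_), f"$.{path}"))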
3 changes: 2 additions & 1 deletion python/spl_transpiler/runtime/commands/stats.py
@@ -2,6 +2,7 @@

from spl_transpiler.runtime.base import enforce_types, Expr
from spl_transpiler.runtime.functions.stats import StatsFunction
from spl_transpiler.runtime.monkeypatches import groupByMaybeExploded


@enforce_types
@@ -16,7 +17,7 @@ def stats(
df, agg_expr = expr.to_pyspark_expr(df)
aggs.append(agg_expr.alias(label))

df = df.groupBy(*(v.to_pyspark_expr() for v in by))
df = groupByMaybeExploded(df, [v.to_pyspark_expr() for v in by])
df = df.agg(*aggs)

return df
19 changes: 19 additions & 0 deletions python/spl_transpiler/runtime/monkeypatches/__init__.py
@@ -0,0 +1,19 @@
from functools import cache

from pyspark.sql import DataFrame, functions as F, GroupedData


def groupByMaybeExploded(self: DataFrame, by: list) -> GroupedData:
by_strings = [c for c in by if isinstance(c, str)]
return self.withColumns(
{
c: F.explode(c)
for c, tp in self.dtypes
if c in by_strings and str(tp).lower().startswith("array<")
}
).groupBy(by)


@cache
def install_monkeypatches():
DataFrame._spltranspiler__groupByMaybeExploded = groupByMaybeExploded
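
A short usage sketch of the patched method (mirrors the generated code in the Rust tests below; the sample data here is made up for illustration):

from pyspark.sql import SparkSession, functions as F
from spl_transpiler.runtime.monkeypatches import install_monkeypatches

spark = SparkSession.builder.getOrCreate()
install_monkeypatches()

# Array-typed "by" columns are exploded before grouping, so each element of
# `tags` forms its own group; scalar columns pass through unchanged.
df = spark.createDataFrame([(["a", "b"], 1), (["a"], 2)], ["tags", "n"])
df._spltranspiler__groupByMaybeExploded(["tags"]).agg(F.sum("n").alias("total")).show()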
2 changes: 1 addition & 1 deletion src/commands/cmd_dedup/spl.rs
@@ -46,7 +46,7 @@ impl_pyclass!(DedupCommand {
// case Field(myVal) => !myVal.toLowerCase(Locale.ROOT).equals("sortby")
// }.rep(1)
fn dedup_field_rep(input: &str) -> IResult<&str, Vec<Field>> {
many1(ws(verify(field, |f| f.0.to_ascii_lowercase() != "sortby")))(input)
many1(ws(verify(field, |f| !f.0.eq_ignore_ascii_case("sortby"))))(input)
}

//
4 changes: 2 additions & 2 deletions src/commands/cmd_lookup/pyspark.rs
@@ -41,7 +41,7 @@ impl PipelineTransformer for LookupCommand {
None => {
vec![column_like!(col("main.*")), column_like!(col("lookup.*"))]
}
Some(LookupOutput { kv, fields }) if kv.to_ascii_uppercase() == "OUTPUT" => {
Some(LookupOutput { kv, fields }) if kv.eq_ignore_ascii_case("OUTPUT") => {
let mut cols = vec![column_like!(col("main.*"))];
for field in fields {
cols.push(match field {
@@ -60,7 +60,7 @@
}
cols
}
Some(LookupOutput { kv, fields }) if kv.to_ascii_uppercase() == "OUTPUTNEW" => {
Some(LookupOutput { kv, fields }) if kv.eq_ignore_ascii_case("OUTPUTNEW") => {
bail!("UNIMPLEMENTED: `lookup` command with `OUTPUTNEW` not supported")
}
output => bail!("Unsupported output definition for `lookup`: {:?}", output),
4 changes: 2 additions & 2 deletions src/commands/cmd_lookup/spl.rs
@@ -40,11 +40,11 @@ impl_pyclass!(LookupCommand { dataset: String, fields: Vec<FieldLike>, output: O
pub fn field_rep(input: &str) -> IResult<&str, Vec<FieldLike>> {
comma_or_space_separated_list1(alt((
map(
verify(aliased_field, |v| v.alias.to_ascii_lowercase() != "output"),
verify(aliased_field, |v| !v.alias.eq_ignore_ascii_case("output")),
FieldLike::AliasedField,
),
map(
verify(field, |v| v.0.to_ascii_lowercase() != "output"),
verify(field, |v| !v.0.eq_ignore_ascii_case("output")),
FieldLike::Field,
),
)))
2 changes: 1 addition & 1 deletion src/commands/cmd_search/spl.rs
@@ -66,7 +66,7 @@ impl SplCommand<SearchCommand> for SearchParser {
mod tests {
use super::*;
use crate::spl::ast;
use crate::spl::parser::field_in;
use crate::spl::parser::{expr, field_in};
use crate::spl::utils::test::*;
use nom::Finish;
use rstest::rstest;
71 changes: 58 additions & 13 deletions src/commands/cmd_stats/pyspark.rs
@@ -4,6 +4,7 @@ use crate::pyspark::alias::Aliasable;
use crate::pyspark::ast::ColumnLike::FunctionCall;
use crate::pyspark::ast::*;
use crate::pyspark::base::{PysparkTranspileContext, ToSparkExpr};
use crate::pyspark::singletons::CodeTransformerType;
use crate::pyspark::transpiler::{PipelineTransformState, PipelineTransformer};
use crate::spl::ast;
use anyhow::bail;
@@ -101,16 +102,27 @@ impl PipelineTransformer for StatsCommand {
df = df_;
aggs.push(e);
}
df = df.group_by(
self.by
.clone()
.unwrap_or_default()
.into_iter()
.map(stats_utils::transform_by_field)
.collect(),
);
let by_columns: Vec<_> = self
.by
.clone()
.unwrap_or_default()
.into_iter()
.map(stats_utils::transform_by_field)
.map(|col| match col {
ColumnOrName::Column(col) => RuntimeExpr::from(col.unaliased()),
ColumnOrName::Name(name) => RuntimeExpr::from(column_like!(py_lit(name))),
})
.collect();

let df = df
.dataframe_method(
"_spltranspiler__groupByMaybeExploded",
vec![PyList(by_columns).into()],
Vec::new(),
)
.requires(CodeTransformerType::MonkeyPatch);

df = df.agg(aggs);
let df = df.agg(aggs);

Ok(state.with_df(df))
}
@@ -127,7 +139,11 @@ mod tests {
r#"stats
count min(_time) as firstTime max(_time) as lastTime
by Processes.parent_process_name Processes.parent_process Processes.process_name Processes.process_id Processes.process Processes.dest Processes.user"#,
r#"spark.table("main").groupBy([
r#"
from spl_transpiler.runtime.monkeypatches import install_monkeypatches
install_monkeypatches()
spark.table("main")._spltranspiler__groupByMaybeExploded([
"Processes.parent_process_name",
"Processes.parent_process",
"Processes.process_name",
@@ -149,7 +165,11 @@
r#"stats
count min(_time) as firstTime max(_time) as lastTime
by Web.http_user_agent, Web.status Web.http_method, Web.url, Web.url_length, Web.src, Web.dest, sourcetype"#,
r#"spark.table("main").groupBy([
r#"
from spl_transpiler.runtime.monkeypatches import install_monkeypatches
install_monkeypatches()
spark.table("main")._spltranspiler__groupByMaybeExploded([
"Web.http_user_agent",
"Web.status",
"Web.http_method",
@@ -175,7 +195,11 @@

generates(
query,
r#"spark.table("main").groupBy([
r#"
from spl_transpiler.runtime.monkeypatches import install_monkeypatches
install_monkeypatches()
spark.table("main")._spltranspiler__groupByMaybeExploded([
F.window("_time", "1 hours"),
"Processes.user",
"Processes.process_id",
@@ -204,7 +228,11 @@ mod tests {
// TODO: The `like` strings should be r strings or escaped
generates(
query,
r#"spark.table("main").groupBy([
r#"
from spl_transpiler.runtime.monkeypatches import install_monkeypatches
install_monkeypatches()
spark.table("main")._spltranspiler__groupByMaybeExploded([
"Processes.original_file_name",
"Processes.parent_process_name",
"Processes.parent_process",
@@ -254,4 +282,21 @@ df_1
"#,
)
}

#[rstest]
fn test_stats_6() {
let query = r#"stats
count
by _time"#;

generates(
query,
r#"
from spl_transpiler.runtime.monkeypatches import install_monkeypatches
install_monkeypatches()
spark.table("main")._spltranspiler__groupByMaybeExploded(["_time"]).agg(F.count(F.lit(1)).alias("count"))
"#,
)
}
}
4 changes: 2 additions & 2 deletions src/format_python.rs
@@ -17,15 +17,15 @@ mod tests {
#[rstest]
fn test_format_python_code_1() {
assert_eq!(
format_python_code(r#"some_func( 'yolo')"#.to_string()).unwrap(),
format_python_code(r#"some_func( 'yolo')"#).unwrap(),
r#"some_func("yolo")"#.to_string(),
);
}

#[rstest]
fn test_format_python_code_2() {
assert_eq!(
format_python_code(r#"spark.table('main')"#.to_string()).unwrap(),
format_python_code(r#"spark.table('main')"#).unwrap(),
r#"spark.table("main")"#.to_string(),
);
}