Skip to content

Commit

Permalink
feat: query 11 implementation (#887)
Browse files Browse the repository at this point in the history
* Add query11

* Fixed query error
  • Loading branch information
montanarograziano authored Aug 30, 2024
1 parent b7707dd commit d9975f6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tpch/execute/q11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pathlib import Path

import pandas as pd
import polars as pl
from queries import q11

pd.options.mode.copy_on_write = True
pd.options.future.infer_string = True

nation = Path("data") / "nation.parquet"
partsupp = Path("data") / "partsupp.parquet"
supplier = Path("data") / "supplier.parquet"

IO_FUNCS = {
"pandas": lambda x: pd.read_parquet(x, engine="pyarrow"),
"pandas[pyarrow]": lambda x: pd.read_parquet(
x, engine="pyarrow", dtype_backend="pyarrow"
),
"polars[eager]": lambda x: pl.read_parquet(x),
"polars[lazy]": lambda x: pl.scan_parquet(x),
}

tool = "pandas"
fn = IO_FUNCS[tool]
print(q11.query(fn(nation), fn(partsupp), fn(supplier)))

tool = "pandas[pyarrow]"
fn = IO_FUNCS[tool]
print(q11.query(fn(nation), fn(partsupp), fn(supplier)))

tool = "polars[eager]"
fn = IO_FUNCS[tool]
print(q11.query(fn(nation), fn(partsupp), fn(supplier)))

tool = "polars[lazy]"
fn = IO_FUNCS[tool]
print(q11.query(fn(nation), fn(partsupp), fn(supplier)).collect())
43 changes: 43 additions & 0 deletions tpch/queries/q11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from datetime import datetime

import narwhals as nw
from narwhals.typing import FrameT


@nw.narwhalify
def query(
nation_ds_raw: FrameT,
partsupp_ds_raw: FrameT,
supplier_ds_raw: FrameT,
) -> FrameT:
var1 = datetime(1993, 10, 1)
var2 = datetime(1994, 1, 1)

nation_ds = nw.from_native(nation_ds_raw)
partsupp_ds = nw.from_native(partsupp_ds_raw)
supplier_ds = nw.from_native(supplier_ds_raw)

var1 = "GERMANY"
var2 = 0.0001

q1 = (
partsupp_ds.join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey")
.join(nation_ds, left_on="s_nationkey", right_on="n_nationkey")
.filter(nw.col("n_name") == var1)
)
q2 = q1.select(
(nw.col("ps_supplycost") * nw.col("ps_availqty")).sum().round(2).alias("tmp")
* var2
)

q_final = (
q1.with_columns((nw.col("ps_supplycost") * nw.col("ps_availqty")).alias("value"))
.group_by("ps_partkey")
.agg(nw.sum("value"))
.join(q2, how="cross")
.filter(nw.col("value") > nw.col("tmp"))
.select("ps_partkey", "value")
.sort("value", descending=True)
)

return nw.to_native(q_final)

0 comments on commit d9975f6

Please sign in to comment.