Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Aug 19, 2024
1 parent 989a09e commit eeb7347
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 142 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/codspeed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:
- name: Run benchmarks
uses: CodSpeedHQ/action@v3
with:
run: pytest tpch/scripts --codspeed
run: pytest tpch/benchmarks --codspeed
File renamed without changes.
120 changes: 120 additions & 0 deletions tpch/benchmarks/queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

from datetime import date

import narwhals.stable.v1 as nw


def q1(lineitem: nw.LazyFrame) -> nw.DataFrame:
var_1 = date(1998, 9, 2)
query_result = (
lineitem.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")),
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.col("l_quantity").sum().alias("sum_qty"),
nw.col("l_extendedprice").sum().alias("sum_base_price"),
nw.col("disc_price").sum().alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.col("l_quantity").mean().alias("avg_qty"),
nw.col("l_extendedprice").mean().alias("avg_price"),
nw.col("l_discount").mean().alias("avg_disc"),
nw.len().alias("count_order"),
],
)
.sort(["l_returnflag", "l_linestatus"])
)
return query_result.collect()


def q2(
region: nw.LazyFrame,
nation: nw.LazyFrame,
supplier: nw.LazyFrame,
part: nw.LazyFrame,
part_supp: nw.LazyFrame,
) -> nw.DataFrame:
var_1 = 15
var_2 = "BRASS"
var_3 = "EUROPE"

tmp = (
part.join(part_supp, left_on="p_partkey", right_on="ps_partkey")
.join(supplier, left_on="ps_suppkey", right_on="s_suppkey")
.join(nation, left_on="s_nationkey", right_on="n_nationkey")
.join(region, left_on="n_regionkey", right_on="r_regionkey")
.filter(
nw.col("p_size") == var_1,
nw.col("p_type").str.ends_with(var_2),
nw.col("r_name") == var_3,
)
)

final_cols = [
"s_acctbal",
"s_name",
"n_name",
"p_partkey",
"p_mfgr",
"s_address",
"s_phone",
"s_comment",
]

return (
tmp.group_by("p_partkey")
.agg(nw.col("ps_supplycost").min().alias("ps_supplycost"))
.join(
tmp,
left_on=["p_partkey", "ps_supplycost"],
right_on=["p_partkey", "ps_supplycost"],
)
.select(final_cols)
.sort(
["s_acctbal", "n_name", "s_name", "p_partkey"],
descending=[True, False, False, False],
)
.head(100)
.collect()
)


def q3(
customer: nw.LazyFrame, line_item: nw.LazyFrame, orders: nw.LazyFrame
) -> nw.DataFrame:
var_1 = var_2 = date(1995, 3, 15)
var_3 = "BUILDING"

return (
customer.filter(nw.col("c_mktsegment") == var_3)
.join(orders, left_on="c_custkey", right_on="o_custkey")
.join(line_item, left_on="o_orderkey", right_on="l_orderkey")
.filter(
nw.col("o_orderdate") < var_2,
nw.col("l_shipdate") > var_1,
)
.with_columns(
(nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue")
)
.group_by(["o_orderkey", "o_orderdate", "o_shippriority"])
.agg([nw.sum("revenue")])
.select(
[
nw.col("o_orderkey").alias("l_orderkey"),
"revenue",
"o_orderdate",
"o_shippriority",
]
)
.sort(by=["revenue", "o_orderdate"], descending=[True, False])
.head(10)
.collect()
)
35 changes: 35 additions & 0 deletions tpch/benchmarks/queries_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import pytest

import narwhals.stable.v1 as nw
from tpch.benchmarks.queries import q1
from tpch.benchmarks.queries import q2
from tpch.benchmarks.queries import q3
from tpch.benchmarks.utils import lib_to_reader

if TYPE_CHECKING:
from pytest_codspeed.plugin import BenchmarkFixture

DATA_FOLDER = Path("tests/data")


@pytest.mark.parametrize("library", ["pandas", "polars", "pyarrow", "dask"])
def test_queries(benchmark: BenchmarkFixture, library: str) -> None:
read_fn = lib_to_reader[library]

customer = nw.from_native(read_fn(DATA_FOLDER / "customer.parquet")).lazy()
lineitem = nw.from_native(read_fn(DATA_FOLDER / "lineitem.parquet")).lazy()
nation = nw.from_native(read_fn(DATA_FOLDER / "nation.parquet")).lazy()
orders = nw.from_native(read_fn(DATA_FOLDER / "orders.parquet")).lazy()
part = nw.from_native(read_fn(DATA_FOLDER / "part.parquet")).lazy()
partsupp = nw.from_native(read_fn(DATA_FOLDER / "partsupp.parquet")).lazy()
region = nw.from_native(read_fn(DATA_FOLDER / "region.parquet")).lazy()
supplier = nw.from_native(read_fn(DATA_FOLDER / "supplier.parquet")).lazy()

_ = benchmark(q1, lineitem)
_ = benchmark(q2, region, nation, supplier, part, partsupp)
_ = benchmark(q3, customer, lineitem, orders)
11 changes: 11 additions & 0 deletions tpch/benchmarks/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import dask.dataframe as dd
import pandas as pd
import polars as pl
import pyarrow.parquet as pq

lib_to_reader = {
"dask": lambda path: dd.read_parquet(path, dtype_backend="pyarrow"),
"pandas": pd.read_parquet,
"polars": pl.scan_parquet,
"pyarrow": pq.read_table,
}
Empty file removed tpch/scripts/__init__.py
Empty file.
57 changes: 0 additions & 57 deletions tpch/scripts/q1_test.py

This file was deleted.

84 changes: 0 additions & 84 deletions tpch/scripts/q2_test.py

This file was deleted.

0 comments on commit eeb7347

Please sign in to comment.