Skip to content

Commit

Permalink
Add PivotTable Transform (#630)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Jul 30, 2024
1 parent c648061 commit 4d3b9e4
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions lumen/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,42 @@ def apply(self, table: DataFrame) -> DataFrame:
return pivot_table


class PivotTable(Transform):
"""
`PivotTable` applies pandas.pivot_table` to the data.
"""

values = param.ListSelector(default=[], doc="""
Column or columns to aggregate.""")

index = param.ListSelector(default=[], doc="""
Column, Grouper, array, or list of the previous
Keys to group by on the pivot table index. If a list is passed,
it can contain any of the other types (except list). If an array is
passed, it must be the same length as the data and will be used in
the same manner as column values.""")

columns = param.ListSelector(default=[], doc="""
Column, Grouper, array, or list of the previous
Keys to group by on the pivot table column. If a list is passed,
it can contain any of the other types (except list). If an array is
passed, it must be the same length as the data and will be used in
the same manner as column values.""")

aggfunc = param.String(default="mean", doc="""
Function, list of functions, dict, default 'mean'""")

_field_params: ClassVar[List[str]] = ['values', 'index', 'columns']

def apply(self, table: DataFrame) -> DataFrame:
values = self.values if len(self.values) > 1 else self.values[0]
columns = self.columns if len(self.columns) > 1 else self.columns[0]
return pd.pivot_table(
table, values=values, index=self.index, columns=columns,
aggfunc=self.aggfunc
)


class Melt(Transform):
"""
`Melt` applies the `pandas.melt` operation given the `id_vars` and `value_vars`.
Expand Down

0 comments on commit 4d3b9e4

Please sign in to comment.