diff --git a/internal/tests/integration/functions_test.go b/internal/tests/integration/functions_test.go index 1f89923..daf7b3c 100644 --- a/internal/tests/integration/functions_test.go +++ b/internal/tests/integration/functions_test.go @@ -38,3 +38,18 @@ func TestIntegration_BuiltinFunctions(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 10, len(res)) } + +func TestIntegration_ColumnGetItem(t *testing.T) { + ctx := context.Background() + spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx) + if err != nil { + t.Fatal(err) + } + + df, _ := spark.Sql(ctx, "select sequence(1,10) as s") + df, err = df.Select(ctx, functions.Col("s").GetItem(2)) + assert.NoError(t, err) + res, err := df.Collect(ctx) + assert.NoError(t, err) + assert.Equal(t, int32(3), res[0].Values()[0]) +} diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go index ce70189..b0b3b87 100644 --- a/spark/sql/column/column.go +++ b/spark/sql/column/column.go @@ -77,6 +77,10 @@ func (c Column) Desc() Column { }) } +func (c Column) GetItem(key any) Column { + return NewColumn(NewUnresolvedExtractValue("getItem", c.expr, NewLiteral(key))) +} + func (c Column) Asc() Column { return NewColumn(&sortExpression{ child: c.expr, diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go index fe50d16..8538ab5 100644 --- a/spark/sql/column/expressions.go +++ b/spark/sql/column/expressions.go @@ -164,6 +164,37 @@ func (c *caseWhenExpression) ToProto(ctx context.Context) (*proto.Expression, er return fun.ToProto(ctx) } +type unresolvedExtractValue struct { + name string + child expression + extraction expression +} + +func (u *unresolvedExtractValue) DebugString() string { + return fmt.Sprintf("%s(%s, %s)", u.name, u.child.DebugString(), u.extraction.DebugString()) +} + +func (u *unresolvedExtractValue) ToProto(ctx context.Context) (*proto.Expression, error) { + expr := newProtoExpression() + child, err := u.child.ToProto(ctx) + if err != nil { + return nil, err + } + + extraction, err := u.extraction.ToProto(ctx) + if err != nil { + return nil, err + } + + expr.ExprType = &proto.Expression_UnresolvedExtractValue_{ + UnresolvedExtractValue: &proto.Expression_UnresolvedExtractValue{ + Child: child, + Extraction: extraction, + }, + } + return expr, nil +} + type unresolvedFunction struct { name string args []expression @@ -209,6 +240,10 @@ func (u *unresolvedFunction) ToProto(ctx context.Context) (*proto.Expression, er return expr, nil } +func NewUnresolvedExtractValue(name string, child expression, extraction expression) expression { + return &unresolvedExtractValue{name: name, child: child, extraction: extraction} +} + func NewUnresolvedFunction(name string, args []expression, isDistinct bool) expression { return &unresolvedFunction{name: name, args: args, isDistinct: isDistinct} }