From 53797bd05e5cfeca6751f339211b6b9dfc730485 Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Sat, 20 Jul 2024 00:16:31 +0200
Subject: [PATCH] [SPARK-48951] Adding `column` and `functions` packages

### What changes were proposed in this pull request?

This patch provides additional base capabilities that are needed to further
parallelize development, by adding skeleton behavior for the `Column`
abstraction in Spark. It allows users to write code like the following:

```
df, _ := spark.Sql(ctx, "select * from range(100)")
col, _ := df.Col("id")
df, _ = df.Filter(col.Gt(functions.Lit(50)))
df.Show(ctx, 100, false)
```

### Why are the changes needed?
Compatibility

### Does this PR introduce _any_ user-facing change?
Adds the necessary public API for `Column` and `functions`.

### How was this patch tested?
Added new tests.

Closes #35 from grundprinzip/plans_and_exprs.

Authored-by: Martin Grund
Signed-off-by: Martin Grund
---
 CONTRIBUTING.md                      |  31 ++-
 .../main.go                          |  36 ++-
 quick-start.md                       |   6 +-
 spark/client/client.go               |   5 +-
 spark/client/testutils/utils.go      |   1 +
 spark/sparkerrors/errors.go          |   1 +
 spark/sql/column/column.go           |  65 +++++
 spark/sql/column/column_test.go      |  83 ++++++
 spark/sql/column/expressions.go      | 263 ++++++++++++++++++
 spark/sql/column/expressions_test.go | 113 ++++++++
 spark/sql/dataframe.go               |  43 ++-
 spark/sql/dataframereader.go         |  27 +-
 spark/sql/dataframereader_test.go    |   2 +-
 spark/sql/functions/builtins.go      |  30 ++
 spark/sql/plan.go                    |  42 ++-
 spark/sql/sparksession.go            |  11 +
 spark/sql/sparksession_test.go       |  13 +
 17 files changed, 735 insertions(+), 37 deletions(-)
 create mode 100644 spark/sql/column/column.go
 create mode 100644 spark/sql/column/column_test.go
 create mode 100644 spark/sql/column/expressions.go
 create mode 100644 spark/sql/column/expressions_test.go
 create mode 100644 spark/sql/functions/builtins.go

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 995f799..4f11da2 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -29,4 +29,33 @@ This requires the following tools to be present in your PATH:
 
 1. Java for checking license headers
 2. `gofumpt` for formatting Go code
-3. `golangci-lint` for linting Go code
\ No newline at end of file
+3. `golangci-lint` for linting Go code
+
+
+### How to write tests
+
+Please make sure that the new code you are adding is properly tested. The code
+base includes mocks that let you simulate most of the necessary API behavior
+without requiring a running Spark instance.
+
+`mock.ProtoClient` is a mock implementation of the `SparkConnectService_ExecutePlanClient`
+interface, which represents the server-side stream of messages sent in response
+to a request.
+
+`testutils.NewConnectServiceClientMock` creates a mock client that implements the
+`SparkConnectServiceClient` interface.
+
+The combination of these two mocks allows you to test the client side of the code
+without having to connect to Spark (see the sketch following this file's diff).
+
+### What to contribute
+
+We welcome contributions of all kinds to the `spark-connect-go` project, in
+particular implementations of functionality that is still missing in the Go
+client. Examples include, but are not limited to:
+
+* Adding an existing feature of the DataFrame API in Go.
+* Adding support for a builtin function of the Spark API in Go.
+* Improving error handling in the client.
+
+If you are unsure whether a contribution is a good fit, feel free to open an
+issue in the Apache Spark Jira.
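To make the mock-based testing workflow above concrete, here is a minimal sketch of a client-side test that runs without a Spark server. It mirrors `TestSparkSessionTable` added later in this patch; the test name and the nil mock arguments are illustrative:

```go
package sql

import (
	"testing"

	"github.com/apache/spark-connect-go/v35/spark/client"
	"github.com/apache/spark-connect-go/v35/spark/client/testutils"
	"github.com/stretchr/testify/assert"
)

func TestClientSideOnly(t *testing.T) {
	// Mock implementation of the SparkConnectServiceClient interface.
	s := testutils.NewConnectServiceClientMock(nil, nil, nil, nil, t)
	// Wrap the mock into an executor that a SparkSession implementation can use.
	c := client.NewSparkExecutorFromClient(s, nil, "")
	session := &sparkSessionImpl{client: c}

	// Plan construction happens purely on the client and can be asserted on.
	df, err := session.Table("table")
	assert.NoError(t, err)
	assert.NotNil(t, df)
}
```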
diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go
index 5e3a9a6..5f63bcc 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -21,6 +21,8 @@ import (
 	"flag"
 	"log"
 
+	"github.com/apache/spark-connect-go/v35/spark/sql/functions"
+
 	"github.com/apache/spark-connect-go/v35/spark/sql"
 	"github.com/apache/spark-connect-go/v35/spark/sql/utils"
 )
@@ -37,7 +39,39 @@ func main() {
 	}
 	defer utils.WarnOnError(spark.Stop, func(err error) {})
 
-	df, err := spark.Sql(ctx, "select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count")
+	df, err := spark.Sql(ctx, "select * from range(100)")
+	if err != nil {
+		log.Fatalf("Failed: %s", err)
+	}
+
+	// Filter with a raw SQL expression string.
+	df, err = df.FilterByString("id < 10")
+	if err != nil {
+		log.Fatalf("Failed: %s", err)
+	}
+	err = df.Show(ctx, 100, false)
+	if err != nil {
+		log.Fatalf("Failed: %s", err)
+	}
+
+	// Filter with a column expression compared against a literal.
+	df, err = spark.Sql(ctx, "select * from range(100)")
+	if err != nil {
+		log.Fatalf("Failed: %s", err)
+	}
+	df, err = df.Filter(functions.Col("id").Lt(functions.Lit(20)))
+	if err != nil {
+		log.Fatalf("Failed: %s", err)
+	}
+	err = df.Show(ctx, 100, false)
+	if err != nil {
+		log.Fatalf("Failed: %s", err)
+	}
+
+	df, err = spark.Sql(ctx, "select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count")
 	if err != nil {
 		log.Fatalf("Failed: %s", err)
 	}
diff --git a/quick-start.md b/quick-start.md
index 72bf139..c26140e 100644
--- a/quick-start.md
+++ b/quick-start.md
@@ -5,7 +5,7 @@ In your Go project `go.mod` file, add `spark-connect-go` library:
 
 ```
 require (
-	github.com/apache/spark-connect-go/v1 master
+	github.com/apache/spark-connect-go/v35 master
 )
 ```
 
@@ -113,9 +113,9 @@ func main() {
 
 ## Start Spark Connect Server (Driver)
 
-Download a Spark distribution (3.4.0+), unzip the folder, run command:
+Download a Spark distribution (3.5.0+), unzip the folder, run command:
 
 ```
-sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.4.0
+sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.5.0
 ```
 
 ## Run Spark Connect Client Application
diff --git a/spark/client/client.go b/spark/client/client.go
index de321ab..ed65f44 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -200,10 +200,7 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
 		case *proto.ExecutePlanResponse_ResultComplete_:
 			c.done = true
 		default:
-			fmt.Printf("Received unsupported response ")
-			//return nil, nil, &sparkerrors.UnsupportedResponseTypeError{
-			//	ResponseType: x,
-			//}
+			// Explicitly ignore messages that we cannot process at the moment.
 		}
 	}
 
diff --git a/spark/client/testutils/utils.go b/spark/client/testutils/utils.go
index 746bfdd..c0fc392 100644
--- a/spark/client/testutils/utils.go
+++ b/spark/client/testutils/utils.go
@@ -37,6 +37,7 @@ type connectServiceClient struct {
 
 func (c *connectServiceClient) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ExecutePlanClient, error) {
 	if c.expectedExecutePlanRequest != nil {
+		// Check that the received request matches the expected request.
 		assert.Equal(c.t, c.expectedExecutePlanRequest, in)
 	}
 	return c.executePlanClient, c.err
diff --git a/spark/sparkerrors/errors.go b/spark/sparkerrors/errors.go
index d537fc0..030db86 100644
--- a/spark/sparkerrors/errors.go
+++ b/spark/sparkerrors/errors.go
@@ -46,6 +46,7 @@ var (
 	ReadError         = errorType(errors.New("read error"))
 	ExecutionError    = errorType(errors.New("execution error"))
 	InvalidInputError = errorType(errors.New("invalid input"))
+	InvalidPlanError  = errorType(errors.New("invalid plan"))
 )
 
 type UnsupportedResponseTypeError struct {
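The new `InvalidPlanError` is attached to a concrete cause via `sparkerrors.WithType`, as `expressions.go` below does for unsupported literal types. A minimal sketch of that pattern (the helper name is illustrative):

```go
package main

import (
	"fmt"

	"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
)

// unsupportedLiteral mirrors the error construction used by
// literalExpression.ToPlan in this patch.
func unsupportedLiteral(v any) error {
	return sparkerrors.WithType(sparkerrors.InvalidPlanError,
		fmt.Errorf("unsupported literal type %T", v))
}

func main() {
	fmt.Println(unsupportedLiteral(struct{}{}))
}
```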
diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go
new file mode 100644
index 0000000..10b966f
--- /dev/null
+++ b/spark/sql/column/column.go
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package column
+
+import proto "github.com/apache/spark-connect-go/v35/internal/generated"
+
+// Column represents a column expression in a DataFrame. It wraps the
+// underlying expression tree that is sent to Spark Connect.
+type Column struct {
+	expr Expression
+}
+
+// ToPlan converts the column expression into its Spark Connect proto
+// representation.
+func (c Column) ToPlan() (*proto.Expression, error) {
+	return c.expr.ToPlan()
+}
+
+// Lt returns a Column expressing that this column is less than the other column.
+func (c Column) Lt(other Column) Column {
+	return NewColumn(NewUnresolvedFunction("<", []Expression{c.expr, other.expr}, false))
+}
+
+// Le returns a Column expressing that this column is less than or equal to the other column.
+func (c Column) Le(other Column) Column {
+	return NewColumn(NewUnresolvedFunction("<=", []Expression{c.expr, other.expr}, false))
+}
+
+// Gt returns a Column expressing that this column is greater than the other column.
+func (c Column) Gt(other Column) Column {
+	return NewColumn(NewUnresolvedFunction(">", []Expression{c.expr, other.expr}, false))
+}
+
+// Ge returns a Column expressing that this column is greater than or equal to the other column.
+func (c Column) Ge(other Column) Column {
+	return NewColumn(NewUnresolvedFunction(">=", []Expression{c.expr, other.expr}, false))
+}
+
+// Eq returns a Column expressing equality of the two columns.
+func (c Column) Eq(other Column) Column {
+	return NewColumn(NewUnresolvedFunction("==", []Expression{c.expr, other.expr}, false))
+}
+
+// Neq returns a Column expressing inequality of the two columns.
+func (c Column) Neq(other Column) Column {
+	cmp := NewUnresolvedFunction("==", []Expression{c.expr, other.expr}, false)
+	return NewColumn(NewUnresolvedFunction("not", []Expression{cmp}, false))
+}
+
+// Mul returns a Column multiplying the two columns.
+func (c Column) Mul(other Column) Column {
+	return NewColumn(NewUnresolvedFunction("*", []Expression{c.expr, other.expr}, false))
+}
+
+// Div returns a Column dividing this column by the other column.
+func (c Column) Div(other Column) Column {
+	return NewColumn(NewUnresolvedFunction("/", []Expression{c.expr, other.expr}, false))
+}
+
+// NewColumn creates a Column from the given expression.
+func NewColumn(expr Expression) Column {
+	return Column{
+		expr: expr,
+	}
+}
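A minimal sketch of how the `Column` helpers above compose into a proto expression tree (column names are illustrative):

```go
package main

import (
	"fmt"

	"github.com/apache/spark-connect-go/v35/spark/sql/column"
)

func main() {
	// Build the expression `col1 < col2` from two unresolved column references.
	col1 := column.NewColumn(column.NewColumnReference("col1"))
	col2 := column.NewColumn(column.NewColumnReference("col2"))
	lt := col1.Lt(col2)

	// ToPlan produces the Spark Connect proto message for the comparison.
	expr, err := lt.ToPlan()
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(expr)
}
```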
diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go
new file mode 100644
index 0000000..7b82ca4
--- /dev/null
+++ b/spark/sql/column/column_test.go
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package column
+
+import (
+	"testing"
+
+	proto "github.com/apache/spark-connect-go/v35/internal/generated"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestColumnFunctions(t *testing.T) {
+	col1 := NewColumn(NewColumnReference("col1"))
+	col2 := NewColumn(NewColumnReference("col2"))
+
+	tests := []struct {
+		name string
+		arg  Column
+		want *proto.Expression
+	}{
+		{
+			name: "TestNewUnresolvedFunction",
+			arg:  NewColumn(NewUnresolvedFunction("id", nil, false)),
+			want: &proto.Expression{
+				ExprType: &proto.Expression_UnresolvedFunction_{
+					UnresolvedFunction: &proto.Expression_UnresolvedFunction{
+						FunctionName: "id",
+						IsDistinct:   false,
+					},
+				},
+			},
+		},
+		{
+			name: "TestComparison",
+			arg:  col1.Lt(col2),
+			want: &proto.Expression{
+				ExprType: &proto.Expression_UnresolvedFunction_{
+					UnresolvedFunction: &proto.Expression_UnresolvedFunction{
+						FunctionName: "<",
+						IsDistinct:   false,
+						Arguments: []*proto.Expression{
+							{
+								ExprType: &proto.Expression_UnresolvedAttribute_{
+									UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
+										UnparsedIdentifier: "col1",
+									},
+								},
+							},
+							{
+								ExprType: &proto.Expression_UnresolvedAttribute_{
+									UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
+										UnparsedIdentifier: "col2",
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.arg.ToPlan()
+			assert.NoError(t, err)
+			assert.Equalf(t, tt.want, got, "Input: %v", tt.arg.expr.DebugString())
+		})
+	}
+}
diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go
new file mode 100644
index 0000000..b818285
--- /dev/null
+++ b/spark/sql/column/expressions.go
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package column
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
+
+	proto "github.com/apache/spark-connect-go/v35/internal/generated"
+)
+
+func newProtoExpression() *proto.Expression {
+	return &proto.Expression{}
+}
+
+// Expression is the interface for all expressions used by Spark Connect.
+type Expression interface {
+	ToPlan() (*proto.Expression, error)
+	DebugString() string
+}
+
+type caseWhenExpression struct {
+	branches []*caseWhenBranch
+	elseExpr Expression
+}
+
+type caseWhenBranch struct {
+	condition Expression
+	value     Expression
+}
+
+// NewCaseWhenExpression creates a CASE WHEN expression from the given
+// branches and an optional else expression.
+func NewCaseWhenExpression(branches []*caseWhenBranch, elseExpr Expression) Expression {
+	return &caseWhenExpression{branches: branches, elseExpr: elseExpr}
+}
+
+func (c *caseWhenExpression) DebugString() string {
+	parts := make([]string, 0)
+	for _, branch := range c.branches {
+		parts = append(parts, fmt.Sprintf("WHEN %s THEN %s", branch.condition.DebugString(), branch.value.DebugString()))
+	}
+	if c.elseExpr != nil {
+		parts = append(parts, fmt.Sprintf("ELSE %s", c.elseExpr.DebugString()))
+	}
+	return fmt.Sprintf("CASE %s END", strings.Join(parts, " "))
+}
+
+func (c *caseWhenExpression) ToPlan() (*proto.Expression, error) {
+	args := make([]Expression, 0)
+	for _, branch := range c.branches {
+		args = append(args, branch.condition)
+		args = append(args, branch.value)
+	}
+
+	if c.elseExpr != nil {
+		args = append(args, c.elseExpr)
+	}
+
+	fun := NewUnresolvedFunction("when", args, false)
+	return fun.ToPlan()
+}
+
+type unresolvedFunction struct {
+	name       string
+	args       []Expression
+	isDistinct bool
+}
+
+func (u *unresolvedFunction) DebugString() string {
+	args := make([]string, 0)
+	for _, arg := range u.args {
+		args = append(args, arg.DebugString())
+	}
+
+	distinct := ""
+	if u.isDistinct {
+		distinct = "DISTINCT "
+	}
+
+	return fmt.Sprintf("%s(%s%s)", u.name, distinct, strings.Join(args, ", "))
+}
+
+func (u *unresolvedFunction) ToPlan() (*proto.Expression, error) {
+	// Convert input args to the proto Expression.
+	var args []*proto.Expression
+	if len(u.args) > 0 {
+		args = make([]*proto.Expression, 0)
+		for _, arg := range u.args {
+			p, e := arg.ToPlan()
+			if e != nil {
+				return nil, e
+			}
+			args = append(args, p)
+		}
+	}
+
+	expr := newProtoExpression()
+	expr.ExprType = &proto.Expression_UnresolvedFunction_{
+		UnresolvedFunction: &proto.Expression_UnresolvedFunction{
+			FunctionName: u.name,
+			Arguments:    args,
+			IsDistinct:   u.isDistinct,
+		},
+	}
+	return expr, nil
+}
+
+// NewUnresolvedFunction creates an expression that calls the given function by name.
+func NewUnresolvedFunction(name string, args []Expression, isDistinct bool) Expression {
+	return &unresolvedFunction{name: name, args: args, isDistinct: isDistinct}
+}
+
+type columnAlias struct {
+	alias    []string
+	expr     Expression
+	metadata *string
+}
+
+// NewColumnAlias aliases the given expression with a single name.
+func NewColumnAlias(alias string, expr Expression) Expression {
+	return &columnAlias{alias: []string{alias}, expr: expr}
+}
+
+// NewColumnAliasFromNameParts aliases the given expression with a multi-part name.
+func NewColumnAliasFromNameParts(alias []string, expr Expression) Expression {
+	return &columnAlias{alias: alias, expr: expr}
+}
+
+func (c *columnAlias) DebugString() string {
+	child := c.expr.DebugString()
+	alias := strings.Join(c.alias, ".")
+	return fmt.Sprintf("%s AS %s", child, alias)
+}
+
+func (c *columnAlias) ToPlan() (*proto.Expression, error) {
+	expr := newProtoExpression()
+	child, err := c.expr.ToPlan()
+	if err != nil {
+		return nil, err
+	}
+	expr.ExprType = &proto.Expression_Alias_{
+		Alias: &proto.Expression_Alias{
+			Expr:     child,
+			Name:     c.alias,
+			Metadata: c.metadata,
+		},
+	}
+	return expr, nil
+}
+
+type columnReference struct {
+	unparsedIdentifier string
+	planId             *int64
+}
+
+// NewColumnReference creates an unresolved column reference by name.
+func NewColumnReference(unparsedIdentifier string) Expression {
+	return &columnReference{unparsedIdentifier: unparsedIdentifier}
+}
+
+// NewColumnReferenceWithPlanId creates an unresolved column reference that is
+// bound to the plan with the given plan ID.
+func NewColumnReferenceWithPlanId(unparsedIdentifier string, planId int64) Expression {
+	return &columnReference{unparsedIdentifier: unparsedIdentifier, planId: &planId}
+}
+
+func (c *columnReference) DebugString() string {
+	return c.unparsedIdentifier
+}
+
+func (c *columnReference) ToPlan() (*proto.Expression, error) {
+	expr := newProtoExpression()
+	expr.ExprType = &proto.Expression_UnresolvedAttribute_{
+		UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
+			UnparsedIdentifier: c.unparsedIdentifier,
+			PlanId:             c.planId,
+		},
+	}
+	return expr, nil
+}
+
+type sqlExpression struct {
+	expressionString string
+}
+
+// NewSQLExpression creates an expression from a raw SQL expression string.
+func NewSQLExpression(expression string) Expression {
+	return &sqlExpression{expressionString: expression}
+}
+
+func (s *sqlExpression) DebugString() string {
+	return s.expressionString
+}
+
+func (s *sqlExpression) ToPlan() (*proto.Expression, error) {
+	expr := newProtoExpression()
+	expr.ExprType = &proto.Expression_ExpressionString_{
+		ExpressionString: &proto.Expression_ExpressionString{
+			Expression: s.expressionString,
+		},
+	}
+	return expr, nil
+}
+
+type literalExpression struct {
+	value any
+}
+
+func (l *literalExpression) DebugString() string {
+	return fmt.Sprintf("%v", l.value)
+}
+
+func (l *literalExpression) ToPlan() (*proto.Expression, error) {
+	expr := newProtoExpression()
+	expr.ExprType = &proto.Expression_Literal_{
+		Literal: &proto.Expression_Literal{},
+	}
+	// Map the Go type of the value onto the corresponding proto literal type.
+	// Unsigned values are widened so that they always fit.
+	switch v := l.value.(type) {
+	case int8:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Byte{Byte: int32(v)}
+	case int16:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Short{Short: int32(v)}
+	case int32:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Integer{Integer: v}
+	case int64:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: v}
+	case uint8:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Short{Short: int32(v)}
+	case uint16:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Integer{Integer: int32(v)}
+	case uint32:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: int64(v)}
+	case float32:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Float{Float: v}
+	case float64:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Double{Double: v}
+	case string:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_String_{String_: v}
+	case bool:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Boolean{Boolean: v}
+	case []byte:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Binary{Binary: v}
+	case int:
+		expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: int64(v)}
+	default:
+		return nil, sparkerrors.WithType(sparkerrors.InvalidPlanError,
+			fmt.Errorf("unsupported literal type %T", v))
+	}
+	return expr, nil
+}
+
+// NewLiteral creates a literal expression from the given Go value.
+func NewLiteral(value any) Expression {
+	return &literalExpression{value: value}
+}
diff --git a/spark/sql/column/expressions_test.go b/spark/sql/column/expressions_test.go
new file mode 100644
index 0000000..10777c7
--- /dev/null
+++ b/spark/sql/column/expressions_test.go
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package column
+
+import (
+	"reflect"
+	"testing"
+
+	proto "github.com/apache/spark-connect-go/v35/internal/generated"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNewUnresolvedFunction(t *testing.T) {
+	colRef := NewColumnReference("martin")
+	colRefPlan, _ := colRef.ToPlan()
+	type args struct {
+		name       string
+		arguments  []Expression
+		isDistinct bool
+	}
+	tests := []struct {
+		name string
+		args args
+		want *proto.Expression
+	}{
+		{
+			name: "TestNewUnresolvedFunction",
+			args: args{
+				name:       "id",
+				arguments:  nil,
+				isDistinct: false,
+			},
+			want: &proto.Expression{
+				ExprType: &proto.Expression_UnresolvedFunction_{
+					UnresolvedFunction: &proto.Expression_UnresolvedFunction{
+						FunctionName: "id",
+						IsDistinct:   false,
+					},
+				},
+			},
+		},
+		{
+			name: "TestNewUnresolvedWithArguments",
+			args: args{
+				name:       "id",
+				arguments:  []Expression{colRef},
+				isDistinct: false,
+			},
+			want: &proto.Expression{
+				ExprType: &proto.Expression_UnresolvedFunction_{
+					UnresolvedFunction: &proto.Expression_UnresolvedFunction{
+						FunctionName: "id",
+						IsDistinct:   false,
+						Arguments: []*proto.Expression{
+							colRefPlan,
+						},
+					},
+				},
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := NewUnresolvedFunction(tt.args.name, tt.args.arguments, tt.args.isDistinct).ToPlan()
+			assert.NoError(t, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
+func TestNewSQLExpression(t *testing.T) {
+	type args struct {
+		expression string
+	}
+	tests := []struct {
+		name string
+		args args
+		want *sqlExpression
+	}{
+		{
+			name: "TestNewSQLExpression",
+			args: args{
+				expression: "id < 10",
+			},
+			want: &sqlExpression{
+				expressionString: "id < 10",
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := NewSQLExpression(tt.args.expression); !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("NewSQLExpression() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 273eaa6..f04e43d 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -20,6 +20,9 @@ import (
 	"context"
 	"fmt"
 
+	"github.com/apache/spark-connect-go/v35/spark/sql/column"
+	"github.com/apache/spark-connect-go/v35/spark/sql/functions"
+
 	"github.com/apache/spark-connect-go/v35/spark/sql/types"
 
 	proto "github.com/apache/spark-connect-go/v35/internal/generated"
@@ -53,6 +56,10 @@ type DataFrame interface {
 	Repartition(numPartitions int, columns []string) (DataFrame, error)
 	// RepartitionByRange re-partitions a data frame by range partition.
 	RepartitionByRange(numPartitions int, columns []RangePartitionColumn) (DataFrame, error)
+
+	// Filter filters the data frame with the given column condition.
+	Filter(condition column.Column) (DataFrame, error)
+	// FilterByString filters the data frame with the given SQL condition string.
+	FilterByString(condition string) (DataFrame, error)
+	// Col returns a column reference by name that is bound to this data frame's plan ID.
+	Col(name string) (column.Column, error)
 }
 
 type RangePartitionColumn struct {
@@ -261,12 +268,7 @@ func (df *dataFrameImpl) RepartitionByRange(numPartitions int, columns []RangePa
 func (df *dataFrameImpl) createPlan() *proto.Plan {
 	return &proto.Plan{
 		OpType: &proto.Plan_Root{
-			Root: &proto.Relation{
-				Common: &proto.RelationCommon{
-					PlanId: newPlanId(),
-				},
-				RelType: df.relation.RelType,
-			},
+			Root: df.relation,
 		},
 	}
 }
@@ -292,3 +294,32 @@ func (df *dataFrameImpl) repartitionByExpressions(numPartitions int, partitionEx
 	}
 	return NewDataFrame(df.session, newRelation), nil
 }
+
+func (df *dataFrameImpl) Filter(condition column.Column) (DataFrame, error) {
+	cnd, err := condition.ToPlan()
+	if err != nil {
+		return nil, err
+	}
+
+	rel := &proto.Relation{
+		Common: &proto.RelationCommon{
+			PlanId: newPlanId(),
+		},
+		RelType: &proto.Relation_Filter{
+			Filter: &proto.Filter{
+				Input:     df.relation,
+				Condition: cnd,
+			},
+		},
+	}
+	return NewDataFrame(df.session, rel), nil
+}
+
+func (df *dataFrameImpl) FilterByString(condition string) (DataFrame, error) {
+	return df.Filter(functions.Expr(condition))
+}
+
+func (df *dataFrameImpl) Col(name string) (column.Column, error) {
+	planId := df.relation.Common.GetPlanId()
+	return column.NewColumn(column.NewColumnReferenceWithPlanId(name, planId)), nil
+}
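A short sketch of the filter APIs added to `DataFrame` above, mirroring the example from the PR description (the helper name is illustrative):

```go
package main

import (
	"github.com/apache/spark-connect-go/v35/spark/sql"
	"github.com/apache/spark-connect-go/v35/spark/sql/functions"
)

// filterExample demonstrates both filter flavors added in this patch.
func filterExample(df sql.DataFrame) (sql.DataFrame, error) {
	// Col binds the reference to this data frame's plan ID, keeping it
	// unambiguous when expressions span multiple data frames.
	col, err := df.Col("id")
	if err != nil {
		return nil, err
	}
	// Equivalent to the SQL predicate `id > 50`.
	filtered, err := df.Filter(col.Gt(functions.Lit(50)))
	if err != nil {
		return nil, err
	}
	// FilterByString accepts a raw SQL expression instead.
	return filtered.FilterByString("id < 60")
}
```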
diff --git a/spark/sql/dataframereader.go b/spark/sql/dataframereader.go
index cc1b8d3..5ca3a33 100644
--- a/spark/sql/dataframereader.go
+++ b/spark/sql/dataframereader.go
@@ -16,10 +16,6 @@
 
 package sql
 
-import (
-	proto "github.com/apache/spark-connect-go/v35/internal/generated"
-)
-
 // DataFrameReader supports reading data from storage and returning a data frame.
 // TODO needs to implement other methods like Option(), Schema(), and also "strong typed"
 // reading (e.g. Parquet(), Orc(), Csv(), etc.
@@ -28,6 +24,8 @@ type DataFrameReader interface {
 	Format(source string) DataFrameReader
 	// Load reads the underlying data and returns a data frame.
 	Load(path string) (DataFrame, error)
+	// Table reads a table from the underlying data source and returns a data frame.
+	Table(name string) (DataFrame, error)
 }
 
 // dataFrameReaderImpl is an implementation of DataFrameReader interface.
@@ -43,6 +41,10 @@ func NewDataframeReader(session *sparkSessionImpl) DataFrameReader {
 	}
 }
 
+func (w *dataFrameReaderImpl) Table(name string) (DataFrame, error) {
+	return NewDataFrame(w.sparkSession, newReadTableRelation(name)), nil
+}
+
 func (w *dataFrameReaderImpl) Format(source string) DataFrameReader {
 	w.formatSource = source
 	return w
@@ -53,20 +55,5 @@ func (w *dataFrameReaderImpl) Load(path string) (DataFrame, error) {
 	if w.formatSource != "" {
 		format = w.formatSource
 	}
-	return NewDataFrame(w.sparkSession, toRelation(path, format)), nil
-}
-
-func toRelation(path, format string) *proto.Relation {
-	return &proto.Relation{
-		RelType: &proto.Relation_Read{
-			Read: &proto.Read{
-				ReadType: &proto.Read_DataSource_{
-					DataSource: &proto.Read_DataSource{
-						Format: &format,
-						Paths:  []string{path},
-					},
-				},
-			},
-		},
-	}
+	return NewDataFrame(w.sparkSession, newReadWithFormatAndPath(path, format)), nil
 }
diff --git a/spark/sql/dataframereader_test.go b/spark/sql/dataframereader_test.go
index 572df1d..8b17794 100644
--- a/spark/sql/dataframereader_test.go
+++ b/spark/sql/dataframereader_test.go
@@ -34,7 +34,7 @@ func TestLoadCreatesADataFrame(t *testing.T) {
 func TestRelationContainsPathAndFormat(t *testing.T) {
 	formatSource := "source"
 	path := "path"
-	relation := toRelation(path, formatSource)
+	relation := newReadWithFormatAndPath(path, formatSource)
 	assert.NotNil(t, relation)
 	assert.Equal(t, &formatSource, relation.GetRead().GetDataSource().Format)
 	assert.Equal(t, path, relation.GetRead().GetDataSource().Paths[0])
diff --git a/spark/sql/functions/builtins.go b/spark/sql/functions/builtins.go
new file mode 100644
index 0000000..ed1b3d1
--- /dev/null
+++ b/spark/sql/functions/builtins.go
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package functions
+
+import "github.com/apache/spark-connect-go/v35/spark/sql/column"
+
+// Expr returns a Column backed by the given SQL expression string.
+func Expr(expr string) column.Column {
+	return column.NewColumn(column.NewSQLExpression(expr))
+}
+
+// Col returns a Column referencing the column with the given name.
+func Col(name string) column.Column {
+	return column.NewColumn(column.NewColumnReference(name))
+}
+
+// Lit returns a Column wrapping the given value as a literal.
+func Lit(value any) column.Column {
+	return column.NewColumn(column.NewLiteral(value))
+}
diff --git a/spark/sql/plan.go b/spark/sql/plan.go
index 66b9e05..8ecacab 100644
--- a/spark/sql/plan.go
+++ b/spark/sql/plan.go
@@ -16,7 +16,11 @@
 
 package sql
 
-import "sync/atomic"
+import (
+	"sync/atomic"
+
+	proto "github.com/apache/spark-connect-go/v35/internal/generated"
+)
 
 var atomicInt64 atomic.Int64
 
@@ -24,3 +28,39 @@ func newPlanId() *int64 {
 	v := atomicInt64.Add(1)
 	return &v
 }
+
+func resetPlanIdForTesting() {
+	atomicInt64.Store(0)
+}
+
+func newReadTableRelation(table string) *proto.Relation {
+	return &proto.Relation{
+		Common: &proto.RelationCommon{
+			PlanId: newPlanId(),
+		},
+		RelType: &proto.Relation_Read{
+			Read: &proto.Read{
+				ReadType: &proto.Read_NamedTable_{
+					NamedTable: &proto.Read_NamedTable{
+						UnparsedIdentifier: table,
+					},
+				},
+			},
+		},
+	}
+}
+
+func newReadWithFormatAndPath(path, format string) *proto.Relation {
+	return &proto.Relation{
+		RelType: &proto.Relation_Read{
+			Read: &proto.Read{
+				ReadType: &proto.Read_DataSource_{
+					DataSource: &proto.Read_DataSource{
+						Format: &format,
+						Paths:  []string{path},
+					},
+				},
+			},
+		},
+	}
+}
diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go
index 78461a3..d3b61b2 100644
--- a/spark/sql/sparksession.go
+++ b/spark/sql/sparksession.go
@@ -32,6 +32,7 @@ type SparkSession interface {
 	Read() DataFrameReader
 	Sql(ctx context.Context, query string) (DataFrame, error)
 	Stop() error
+	Table(name string) (DataFrame, error)
 }
 
 // NewSessionBuilder creates a new session builder for starting a new spark session
@@ -116,6 +117,9 @@ func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (DataFrame, er
 	val, ok := properties["sql_command_result"]
 	if !ok {
 		plan := &proto.Relation{
+			Common: &proto.RelationCommon{
+				PlanId: newPlanId(),
+			},
 			RelType: &proto.Relation_Sql{
 				Sql: &proto.SQL{
 					Query: query,
@@ -125,6 +129,9 @@ func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (DataFrame, er
 		return NewDataFrame(s, plan), nil
 	} else {
 		rel := val.(*proto.Relation)
+		rel.Common = &proto.RelationCommon{
+			PlanId: newPlanId(),
+		}
 		return NewDataFrame(s, rel), nil
 	}
 }
@@ -132,3 +139,7 @@
 func (s *sparkSessionImpl) Stop() error {
 	return nil
 }
+
+func (s *sparkSessionImpl) Table(name string) (DataFrame, error) {
+	return s.Read().Table(name)
+}
diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go
index 775dce6..84d4724 100644
--- a/spark/sql/sparksession_test.go
+++ b/spark/sql/sparksession_test.go
@@ -35,6 +35,19 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+func TestSparkSessionTable(t *testing.T) {
+	resetPlanIdForTesting()
+	plan := newReadTableRelation("table")
+	resetPlanIdForTesting()
+	s := testutils.NewConnectServiceClientMock(nil, nil, nil, nil, t)
+	c := client.NewSparkExecutorFromClient(s, nil, "")
+	session := &sparkSessionImpl{client: c}
+	df, err := session.Table("table")
+	assert.NoError(t, err)
+	dfPlan := df.(*dataFrameImpl).relation
+	assert.Equal(t, plan, dfPlan)
+}
+
 func TestSQLCallsExecutePlanWithSQLOnClient(t *testing.T) {
 	ctx := context.Background()
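Putting the pieces together, here is a sketch of the end-to-end usage enabled by this patch, assuming an already-constructed session as in `main.go` above (the table name and helper name are illustrative):

```go
package main

import (
	"context"

	"github.com/apache/spark-connect-go/v35/spark/sql"
	"github.com/apache/spark-connect-go/v35/spark/sql/functions"
)

// showSmallIds reads a named table, filters it, and prints the first rows.
func showSmallIds(ctx context.Context, spark sql.SparkSession) error {
	// Table is shorthand for spark.Read().Table(name).
	df, err := spark.Table("my_table")
	if err != nil {
		return err
	}
	// Build the predicate `id < 20` from a column reference and a literal.
	df, err = df.Filter(functions.Col("id").Lt(functions.Lit(20)))
	if err != nil {
		return err
	}
	return df.Show(ctx, 100, false)
}
```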