From 2e0fe708d0f1028d8f61674485516c6266d282ae Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Sun, 7 Jul 2024 00:19:37 +0200 Subject: [PATCH 1/7] [WIP] Client Refactoring --- README.md | 34 +--- .../main.go | 4 +- spark/client/client.go | 128 ++++++++++++++ spark/client/client_test.go | 58 +++++++ spark/client/testutils/utils.go | 81 +++++++++ spark/sql/dataframe.go | 24 +-- spark/sql/dataframereader.go | 8 +- spark/sql/dataframewriter.go | 13 +- spark/sql/dataframewriter_test.go | 23 ++- spark/sql/mocks_test.go | 6 +- spark/sql/{session => }/sparksession.go | 70 ++------ spark/sql/sparksession_test.go | 156 ++++++++++++++++++ 12 files changed, 483 insertions(+), 122 deletions(-) create mode 100644 spark/client/client.go create mode 100644 spark/client/client_test.go create mode 100644 spark/client/testutils/utils.go rename spark/sql/{session => }/sparksession.go (63%) create mode 100644 spark/sql/sparksession_test.go diff --git a/README.md b/README.md index 7832edb..234191f 100644 --- a/README.md +++ b/README.md @@ -50,37 +50,9 @@ See [Quick Start Guide](quick-start.md) ## High Level Design -Following [diagram](https://textik.com/#ac299c8f32c4c342) shows main code in current prototype: - -``` - +-------------------+ - | | - | dataFrameImpl | - | | - +-------------------+ - | - | - + - +-------------------+ - | | - | sparkSessionImpl | - | | - +-------------------+ - | - | - + -+---------------------------+ +----------------+ -| | | | -| SparkConnectServiceClient |--------------+| Spark Driver | -| | | | -+---------------------------+ +----------------+ -``` - -`SparkConnectServiceClient` is GRPC client which talks to Spark Driver. `sparkSessionImpl` generates `dataFrameImpl` -instances. `dataFrameImpl` uses the GRPC client in `sparkSessionImpl` to communicate with Spark Driver. - -We will mimic the logic in Spark Connect Scala implementation, and adopt Go common practices, e.g. returning `error` object for -error handling. +The overall goal of the design is to find a good balance of principle of the least surprise for +develoeprs that are familiar with the APIs of Apache Spark and idiomatic Go usage. The high-level +structure of the packages follows roughly the PySpark giudance but with Go idioms. ## Contributing diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go index a9c17b0..c4c1ea6 100644 --- a/cmd/spark-connect-example-spark-session/main.go +++ b/cmd/spark-connect-example-spark-session/main.go @@ -21,8 +21,6 @@ import ( "flag" "log" - "github.com/apache/spark-connect-go/v35/spark/sql/session" - "github.com/apache/spark-connect-go/v35/spark/sql" ) @@ -32,7 +30,7 @@ var remote = flag.String("remote", "sc://localhost:15002", func main() { flag.Parse() ctx := context.Background() - spark, err := session.NewSessionBuilder().Remote(*remote).Build(ctx) + spark, err := sql.NewSessionBuilder().Remote(*remote).Build(ctx) if err != nil { log.Fatalf("Failed: %s", err) } diff --git a/spark/client/client.go b/spark/client/client.go new file mode 100644 index 0000000..270bf10 --- /dev/null +++ b/spark/client/client.go @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/apache/spark-connect-go/v35/internal/generated" + proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/sparkerrors" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +// SparkExecutor is the interface for executing a plan in Spark. +// +// This interface does not deal with the public Spark API abstractions but roughly deals on the +// RPC API level and the necessary translation of Arrow to Row objects. +type SparkExecutor interface { + ExecutePlan(ctx context.Context, plan *generated.Plan) (*ExecutePlanClient, error) + AnalyzePlan(ctx context.Context, plan *generated.Plan) (*generated.AnalyzePlanResponse, error) +} + +type SparkExecutorImpl struct { + client proto.SparkConnectServiceClient + metadata metadata.MD + sessionId string +} + +func (s *SparkExecutorImpl) ExecutePlan(ctx context.Context, plan *proto.Plan) (*ExecutePlanClient, error) { + request := proto.ExecutePlanRequest{ + SessionId: s.sessionId, + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + + // Append the other items to the request. + ctx = metadata.NewOutgoingContext(ctx, s.metadata) + c, err := s.client.ExecutePlan(ctx, &request) + if err != nil { + return nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) + } + return NewExecutePlanClient(c), nil +} + +func (s *SparkExecutorImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) { + request := proto.AnalyzePlanRequest{ + SessionId: s.sessionId, + Analyze: &proto.AnalyzePlanRequest_Schema_{ + Schema: &proto.AnalyzePlanRequest_Schema{ + Plan: plan, + }, + }, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + // Append the other items to the request. + ctx = metadata.NewOutgoingContext(ctx, s.metadata) + + response, err := s.client.AnalyzePlan(ctx, &request) + if err != nil { + return nil, sparkerrors.WithType(fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) + } + return response, nil +} + +func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string) SparkExecutor { + client := proto.NewSparkConnectServiceClient(conn) + return &SparkExecutorImpl{ + client: client, + metadata: md, + sessionId: sessionId, + } +} + +func NewSparkExecutorFromClient(client proto.SparkConnectServiceClient, md metadata.MD, sessionId string) SparkExecutor { + return &SparkExecutorImpl{ + client: client, + metadata: md, + sessionId: sessionId, + } +} + +type ExecutePlanClient struct { + generated.SparkConnectService_ExecutePlanClient +} + +func NewExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) *ExecutePlanClient { + return &ExecutePlanClient{ + responseClient, + } +} + +// consumeAll reads through the returned GRPC stream from Spark Connect Driver. It will +// discard the returned data if there is no error. 
This is necessary for handling GRPC response for +// saving data frame, since such consuming will trigger Spark Connect Driver really saving data frame. +// If we do not consume the returned GRPC stream, Spark Connect Driver will not really save data frame. +func (c *ExecutePlanClient) ConsumeAll() error { + for { + _, err := c.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } else { + return sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError) + } + } + } +} diff --git a/spark/client/client_test.go b/spark/client/client_test.go new file mode 100644 index 0000000..e601db8 --- /dev/null +++ b/spark/client/client_test.go @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client_test + +import ( + "context" + "testing" + + proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/client" + "github.com/apache/spark-connect-go/v35/spark/client/testutils" + "github.com/stretchr/testify/assert" +) + +func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { + ctx := context.Background() + response := &proto.AnalyzePlanResponse{} + c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, response, nil, nil), nil, "") + resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestAnalyzePlanFailsIfClientFails(t *testing.T) { + ctx := context.Background() + c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, nil, assert.AnError, nil), nil, "") + resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) + assert.Nil(t, resp) + assert.Error(t, err) +} + +func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { + ctx := context.Background() + plan := &proto.Plan{} + request := &proto.ExecutePlanRequest{ + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(request, &client.ExecutePlanClient{}, nil, nil, t), nil, "") + resp, err := c.ExecutePlan(ctx, plan) + assert.NoError(t, err) + assert.NotNil(t, resp) +} diff --git a/spark/client/testutils/utils.go b/spark/client/testutils/utils.go new file mode 100644 index 0000000..a0313d1 --- /dev/null +++ b/spark/client/testutils/utils.go @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutils + +import ( + "context" + "testing" + + proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/client" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +type connectServiceClient struct { + t *testing.T + + analysePlanResponse *proto.AnalyzePlanResponse + executePlanClient *client.ExecutePlanClient + expectedExecutePlanRequest *proto.ExecutePlanRequest + + err error +} + +func (c *connectServiceClient) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ExecutePlanClient, error) { + if c.expectedExecutePlanRequest != nil { + assert.Equal(c.t, c.expectedExecutePlanRequest, in) + } + return c.executePlanClient, c.err +} + +func (c *connectServiceClient) AnalyzePlan(ctx context.Context, in *proto.AnalyzePlanRequest, opts ...grpc.CallOption) (*proto.AnalyzePlanResponse, error) { + return c.analysePlanResponse, c.err +} + +func (c *connectServiceClient) Config(ctx context.Context, in *proto.ConfigRequest, opts ...grpc.CallOption) (*proto.ConfigResponse, error) { + return nil, c.err +} + +func (c *connectServiceClient) AddArtifacts(ctx context.Context, opts ...grpc.CallOption) (proto.SparkConnectService_AddArtifactsClient, error) { + return nil, c.err +} + +func (c *connectServiceClient) ArtifactStatus(ctx context.Context, in *proto.ArtifactStatusesRequest, opts ...grpc.CallOption) (*proto.ArtifactStatusesResponse, error) { + return nil, c.err +} + +func (c *connectServiceClient) Interrupt(ctx context.Context, in *proto.InterruptRequest, opts ...grpc.CallOption) (*proto.InterruptResponse, error) { + return nil, c.err +} + +func (c *connectServiceClient) ReattachExecute(ctx context.Context, in *proto.ReattachExecuteRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ReattachExecuteClient, error) { + return nil, c.err +} + +func (c *connectServiceClient) ReleaseExecute(ctx context.Context, in *proto.ReleaseExecuteRequest, opts ...grpc.CallOption) (*proto.ReleaseExecuteResponse, error) { + return nil, c.err +} + +func NewConnectServiceClientMock(epr *proto.ExecutePlanRequest, epc *client.ExecutePlanClient, apc *proto.AnalyzePlanResponse, err error, t *testing.T) proto.SparkConnectServiceClient { + return &connectServiceClient{ + t: t, + expectedExecutePlanRequest: epr, + analysePlanResponse: apc, + executePlanClient: epc, + err: err, + } +} diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index ac7473e..fb0a8af 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -66,15 +66,15 @@ type RangePartitionColumn struct { // dataFrameImpl is an implementation of DataFrame interface. type dataFrameImpl struct { - sparkExecutor SparkExecutor - relation *proto.Relation // TODO change to proto.Plan? + session *sparkSessionImpl + relation *proto.Relation // TODO change to proto.Plan? 
} // NewDataFrame creates a new DataFrame -func NewDataFrame(sparkExecutor SparkExecutor, relation *proto.Relation) DataFrame { +func NewDataFrame(session *sparkSessionImpl, relation *proto.Relation) DataFrame { return &dataFrameImpl{ - sparkExecutor: sparkExecutor, - relation: relation, + session: session, + relation: relation, } } @@ -114,7 +114,7 @@ func (df *dataFrameImpl) WriteResult(ctx context.Context, collector ResultCollec }, } - responseClient, err := df.sparkExecutor.ExecutePlan(ctx, plan) + responseClient, err := df.session.client.ExecutePlan(ctx, plan) if err != nil { return sparkerrors.WithType(fmt.Errorf("failed to show dataframe: %w", err), sparkerrors.ExecutionError) } @@ -137,7 +137,7 @@ func (df *dataFrameImpl) WriteResult(ctx context.Context, collector ResultCollec } func (df *dataFrameImpl) Schema(ctx context.Context) (*StructType, error) { - response, err := df.sparkExecutor.AnalyzePlan(ctx, df.createPlan()) + response, err := df.session.client.AnalyzePlan(ctx, df.createPlan()) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to analyze plan: %w", err), sparkerrors.ExecutionError) } @@ -147,7 +147,7 @@ func (df *dataFrameImpl) Schema(ctx context.Context) (*StructType, error) { } func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) { - responseClient, err := df.sparkExecutor.ExecutePlan(ctx, df.createPlan()) + responseClient, err := df.session.client.ExecutePlan(ctx, df.createPlan()) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError) } @@ -196,7 +196,7 @@ func (df *dataFrameImpl) Write() DataFrameWriter { } func (df *dataFrameImpl) Writer() DataFrameWriter { - return newDataFrameWriter(df.sparkExecutor, df.relation) + return newDataFrameWriter(df.session, df.relation) } func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, replace bool, global bool) error { @@ -215,12 +215,12 @@ func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, re }, } - responseClient, err := df.sparkExecutor.ExecutePlan(ctx, plan) + responseClient, err := df.session.client.ExecutePlan(ctx, plan) if err != nil { return sparkerrors.WithType(fmt.Errorf("failed to create temp view %s: %w", viewName, err), sparkerrors.ExecutionError) } - return responseClient.consumeAll() + return responseClient.ConsumeAll() } func (df *dataFrameImpl) Repartition(numPartitions int, columns []string) (DataFrame, error) { @@ -303,7 +303,7 @@ func (df *dataFrameImpl) repartitionByExpressions(numPartitions int, partitionEx }, }, } - return NewDataFrame(df.sparkExecutor, newRelation), nil + return NewDataFrame(df.session, newRelation), nil } func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch, collector ResultCollector) error { diff --git a/spark/sql/dataframereader.go b/spark/sql/dataframereader.go index 17cce98..732ec7c 100644 --- a/spark/sql/dataframereader.go +++ b/spark/sql/dataframereader.go @@ -1,6 +1,8 @@ package sql -import proto "github.com/apache/spark-connect-go/v35/internal/generated" +import ( + proto "github.com/apache/spark-connect-go/v35/internal/generated" +) // DataFrameReader supports reading data from storage and returning a data frame. // TODO needs to implement other methods like Option(), Schema(), and also "strong typed" @@ -14,12 +16,12 @@ type DataFrameReader interface { // dataFrameReaderImpl is an implementation of DataFrameReader interface. 
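Both the reader and the writer are now wired to the session rather than owning an executor themselves. For orientation, a rough round-trip sketch on top of these interfaces; the parquet format, the output path, and the reader's `Format`/`Load` shape are assumptions for illustration and not part of this patch:

```go
package example

import (
	"context"
	"log"

	"github.com/apache/spark-connect-go/v35/spark/sql"
)

// roundTrip saves an existing DataFrame and reads it back through the same session.
func roundTrip(ctx context.Context, spark sql.SparkSession, df sql.DataFrame) error {
	writer := df.Writer()
	writer.Format("parquet")
	writer.Mode("overwrite")
	if err := writer.Save(ctx, "/tmp/round-trip-example"); err != nil {
		return err
	}

	reader := spark.Read()
	reader.Format("parquet")
	back, err := reader.Load("/tmp/round-trip-example")
	if err != nil {
		return err
	}

	rows, err := back.Collect(ctx)
	if err != nil {
		return err
	}
	log.Printf("read %d rows back", len(rows))
	return nil
}
```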
type dataFrameReaderImpl struct { - sparkSession SparkExecutor + sparkSession *sparkSessionImpl formatSource string } // NewDataframeReader creates a new DataFrameReader -func NewDataframeReader(session SparkExecutor) DataFrameReader { +func NewDataframeReader(session *sparkSessionImpl) DataFrameReader { return &dataFrameReaderImpl{ sparkSession: session, } diff --git a/spark/sql/dataframewriter.go b/spark/sql/dataframewriter.go index 4c99788..5dc305c 100644 --- a/spark/sql/dataframewriter.go +++ b/spark/sql/dataframewriter.go @@ -19,12 +19,7 @@ type DataFrameWriter interface { Save(ctx context.Context, path string) error } -type SparkExecutor interface { - ExecutePlan(ctx context.Context, plan *proto.Plan) (*ExecutePlanClient, error) - AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) -} - -func newDataFrameWriter(sparkExecutor SparkExecutor, relation *proto.Relation) DataFrameWriter { +func newDataFrameWriter(sparkExecutor *sparkSessionImpl, relation *proto.Relation) DataFrameWriter { return &dataFrameWriterImpl{ sparkExecutor: sparkExecutor, relation: relation, @@ -33,7 +28,7 @@ func newDataFrameWriter(sparkExecutor SparkExecutor, relation *proto.Relation) D // dataFrameWriterImpl is an implementation of DataFrameWriter interface. type dataFrameWriterImpl struct { - sparkExecutor SparkExecutor + sparkExecutor *sparkSessionImpl relation *proto.Relation saveMode string formatSource string @@ -74,12 +69,12 @@ func (w *dataFrameWriterImpl) Save(ctx context.Context, path string) error { }, }, } - responseClient, err := w.sparkExecutor.ExecutePlan(ctx, plan) + responseClient, err := w.sparkExecutor.client.ExecutePlan(ctx, plan) if err != nil { return err } - return responseClient.consumeAll() + return responseClient.ConsumeAll() } func getSaveMode(mode string) (proto.WriteOperation_SaveMode, error) { diff --git a/spark/sql/dataframewriter_test.go b/spark/sql/dataframewriter_test.go index e886f6d..3c20df0 100644 --- a/spark/sql/dataframewriter_test.go +++ b/spark/sql/dataframewriter_test.go @@ -5,6 +5,8 @@ import ( "io" "testing" + "github.com/apache/spark-connect-go/v35/spark/client" + proto "github.com/apache/spark-connect-go/v35/internal/generated" "github.com/apache/spark-connect-go/v35/spark/mocks" "github.com/stretchr/testify/assert" @@ -39,14 +41,17 @@ func TestGetSaveMode(t *testing.T) { func TestSaveExecutesWriteOperationUntilEOF(t *testing.T) { relation := &proto.Relation{} executor := &testExecutor{ - client: NewExecutePlanClient(&mocks.ProtoClient{ + client: client.NewExecutePlanClient(&mocks.ProtoClient{ Err: io.EOF, }), } + session := &sparkSessionImpl{ + client: executor, + } ctx := context.Background() path := "path" - writer := newDataFrameWriter(executor, relation) + writer := newDataFrameWriter(session, relation) writer.Format("format") writer.Mode("append") err := writer.Save(ctx, path) @@ -56,14 +61,17 @@ func TestSaveExecutesWriteOperationUntilEOF(t *testing.T) { func TestSaveFailsIfAnotherErrorHappensWhenReadingStream(t *testing.T) { relation := &proto.Relation{} executor := &testExecutor{ - client: NewExecutePlanClient(&mocks.ProtoClient{ + client: client.NewExecutePlanClient(&mocks.ProtoClient{ Err: assert.AnError, }), } + session := &sparkSessionImpl{ + client: executor, + } ctx := context.Background() path := "path" - writer := newDataFrameWriter(executor, relation) + writer := newDataFrameWriter(session, relation) writer.Format("format") writer.Mode("append") err := writer.Save(ctx, path) @@ -73,13 +81,16 @@ func 
TestSaveFailsIfAnotherErrorHappensWhenReadingStream(t *testing.T) { func TestSaveFailsIfAnotherErrorHappensWhenExecuting(t *testing.T) { relation := &proto.Relation{} executor := &testExecutor{ - client: NewExecutePlanClient(&mocks.ProtoClient{}), + client: client.NewExecutePlanClient(&mocks.ProtoClient{}), err: assert.AnError, } + session := &sparkSessionImpl{ + client: executor, + } ctx := context.Background() path := "path" - writer := newDataFrameWriter(executor, relation) + writer := newDataFrameWriter(session, relation) writer.Format("format") writer.Mode("append") err := writer.Save(ctx, path) diff --git a/spark/sql/mocks_test.go b/spark/sql/mocks_test.go index ded00e0..ca80b3a 100644 --- a/spark/sql/mocks_test.go +++ b/spark/sql/mocks_test.go @@ -3,16 +3,18 @@ package sql import ( "context" + client2 "github.com/apache/spark-connect-go/v35/spark/client" + proto "github.com/apache/spark-connect-go/v35/internal/generated" ) type testExecutor struct { - client *ExecutePlanClient + client *client2.ExecutePlanClient response *proto.AnalyzePlanResponse err error } -func (t *testExecutor) ExecutePlan(ctx context.Context, plan *proto.Plan) (*ExecutePlanClient, error) { +func (t *testExecutor) ExecutePlan(ctx context.Context, plan *proto.Plan) (*client2.ExecutePlanClient, error) { if t.err != nil { return nil, t.err } diff --git a/spark/sql/session/sparksession.go b/spark/sql/sparksession.go similarity index 63% rename from spark/sql/session/sparksession.go rename to spark/sql/sparksession.go index 8a45fb0..d6daa7a 100644 --- a/spark/sql/session/sparksession.go +++ b/spark/sql/sparksession.go @@ -14,24 +14,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -package session +package sql import ( "context" "fmt" - "github.com/apache/spark-connect-go/v35/spark/client/channel" - proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/client" + "github.com/apache/spark-connect-go/v35/spark/client/channel" "github.com/apache/spark-connect-go/v35/spark/sparkerrors" - "github.com/apache/spark-connect-go/v35/spark/sql" "github.com/google/uuid" "google.golang.org/grpc/metadata" ) type SparkSession interface { - Read() sql.DataFrameReader - Sql(ctx context.Context, query string) (sql.DataFrame, error) + Read() DataFrameReader + Sql(ctx context.Context, query string) (DataFrame, error) Stop() error } @@ -75,25 +74,23 @@ func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) { meta[k] = append(meta[k], v) } - client := proto.NewSparkConnectServiceClient(conn) + sessionId := uuid.NewString() return &sparkSessionImpl{ - sessionId: uuid.NewString(), - client: client, - metadata: meta, + sessionId: sessionId, + client: client.NewSparkExecutor(conn, meta, sessionId), }, nil } type sparkSessionImpl struct { sessionId string - client proto.SparkConnectServiceClient - metadata metadata.MD + client client.SparkExecutor } -func (s *sparkSessionImpl) Read() sql.DataFrameReader { - return sql.NewDataframeReader(s) +func (s *sparkSessionImpl) Read() DataFrameReader { + return NewDataframeReader(s) } -func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (sql.DataFrame, error) { +func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (DataFrame, error) { plan := &proto.Plan{ OpType: &proto.Plan_Command{ Command: &proto.Command{ @@ -105,7 +102,7 @@ func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (sql.DataFrame }, }, } - responseClient, 
err := s.ExecutePlan(ctx, plan) + responseClient, err := s.client.ExecutePlan(ctx, plan) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to execute sql: %s: %w", query, err), sparkerrors.ExecutionError) } @@ -118,49 +115,10 @@ func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (sql.DataFrame if sqlCommandResult == nil { continue } - return sql.NewDataFrame(s, sqlCommandResult.GetRelation()), nil + return NewDataFrame(s, sqlCommandResult.GetRelation()), nil } } func (s *sparkSessionImpl) Stop() error { return nil } - -func (s *sparkSessionImpl) ExecutePlan(ctx context.Context, plan *proto.Plan) (*sql.ExecutePlanClient, error) { - request := proto.ExecutePlanRequest{ - SessionId: s.sessionId, - Plan: plan, - UserContext: &proto.UserContext{ - UserId: "na", - }, - } - // Append the other items to the request. - ctx = metadata.NewOutgoingContext(ctx, s.metadata) - client, err := s.client.ExecutePlan(ctx, &request) - if err != nil { - return nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) - } - return sql.NewExecutePlanClient(client), nil -} - -func (s *sparkSessionImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) { - request := proto.AnalyzePlanRequest{ - SessionId: s.sessionId, - Analyze: &proto.AnalyzePlanRequest_Schema_{ - Schema: &proto.AnalyzePlanRequest_Schema{ - Plan: plan, - }, - }, - UserContext: &proto.UserContext{ - UserId: "na", - }, - } - // Append the other items to the request. - ctx = metadata.NewOutgoingContext(ctx, s.metadata) - - response, err := s.client.AnalyzePlan(ctx, &request) - if err != nil { - return nil, sparkerrors.WithType(fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) - } - return response, nil -} diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go new file mode 100644 index 0000000..e02d8bd --- /dev/null +++ b/spark/sql/sparksession_test.go @@ -0,0 +1,156 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
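With the builder moved from `spark/sql/session` into `spark/sql` and the RPC plumbing hidden behind `SparkExecutor`, application code only touches the `sql` package. A minimal usage sketch of the relocated API; the endpoint and query are placeholders, not taken from this patch:

```go
package main

import (
	"context"
	"log"

	"github.com/apache/spark-connect-go/v35/spark/sql"
)

func main() {
	ctx := context.Background()
	// Point this at your own Spark Connect endpoint.
	spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(ctx)
	if err != nil {
		log.Fatalf("failed to connect: %s", err)
	}
	defer spark.Stop()

	// Sql sends the query as a SqlCommand plan through the SparkExecutor and wraps the
	// returned relation in a DataFrame.
	df, err := spark.Sql(ctx, "select id from range(5)")
	if err != nil {
		log.Fatalf("failed to run query: %s", err)
	}
	if err := df.Show(ctx, 5, false); err != nil {
		log.Fatalf("failed to show result: %s", err)
	}
}
```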
+ +package sql + +import ( + "bytes" + "context" + "testing" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/ipc" + "github.com/apache/arrow/go/v12/arrow/memory" + proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/client" + "github.com/apache/spark-connect-go/v35/spark/client/testutils" + "github.com/apache/spark-connect-go/v35/spark/mocks" + "github.com/apache/spark-connect-go/v35/spark/sparkerrors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSQLCallsExecutePlanWithSQLOnClient(t *testing.T) { + ctx := context.Background() + + query := "select * from bla" + plan := &proto.Plan{ + OpType: &proto.Plan_Command{ + Command: &proto.Command{ + CommandType: &proto.Command_SqlCommand{ + SqlCommand: &proto.SqlCommand{ + Sql: query, + }, + }, + }, + }, + } + request := &proto.ExecutePlanRequest{ + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + + s := testutils.NewConnectServiceClientMock(request, &client.ExecutePlanClient{ + SparkConnectService_ExecutePlanClient: &mocks.ProtoClient{ + RecvResponse: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, + }, + }, + }, + }, nil, nil, t) + c := client.NewSparkExecutorFromClient(s, nil, "") + + session := &sparkSessionImpl{ + client: c, + } + resp, err := session.Sql(ctx, query) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestNewSessionBuilderCreatesASession(t *testing.T) { + ctx := context.Background() + spark, err := NewSessionBuilder().Remote("sc://connection").Build(ctx) + assert.NoError(t, err) + assert.NotNil(t, spark) +} + +func TestNewSessionBuilderFailsIfConnectionStringIsInvalid(t *testing.T) { + ctx := context.Background() + spark, err := NewSessionBuilder().Remote("invalid").Build(ctx) + assert.Error(t, err) + assert.ErrorIs(t, err, sparkerrors.InvalidInputError) + assert.Nil(t, spark) +} + +func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { + ctx := context.Background() + + arrowFields := []arrow.Field{ + { + Name: "show_string", + Type: &arrow.StringType{}, + }, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + recordBuilder.Field(0).(*array.StringBuilder).Append("str1a\nstr1b") + recordBuilder.Field(0).(*array.StringBuilder).Append("str2") + + record := recordBuilder.NewRecord() + defer record.Release() + + err := arrowWriter.Write(record) + require.Nil(t, err) + + query := "select * from bla" + + s := testutils.NewConnectServiceClientMock(nil, &client.ExecutePlanClient{ + SparkConnectService_ExecutePlanClient: &mocks.ProtoClient{ + RecvResponses: []*proto.ExecutePlanResponse{ + { + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, + }, + }, + { + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + RowCount: 1, + Data: buf.Bytes(), + }, + }, + }, + }, + }, + }, nil, nil, t) + c := client.NewSparkExecutorFromClient(s, nil, "") + + session := &sparkSessionImpl{ + client: c, + } + + resp, err := 
session.Sql(ctx, query) + assert.NoError(t, err) + assert.NotNil(t, resp) + writer, err := resp.Repartition(1, []string{"1"}) + assert.NoError(t, err) + collector := &testCollector{} + err = writer.WriteResult(ctx, collector, 1, false) + assert.NoError(t, err) + assert.Equal(t, []any{"str2"}, collector.row) +} From f5de587694f7d3a138456210c4ea85e3e5bc5577 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Sun, 7 Jul 2024 00:36:31 +0200 Subject: [PATCH 2/7] removing old file --- spark/sql/session/sparksession_test.go | 237 ------------------------- 1 file changed, 237 deletions(-) delete mode 100644 spark/sql/session/sparksession_test.go diff --git a/spark/sql/session/sparksession_test.go b/spark/sql/session/sparksession_test.go deleted file mode 100644 index 002c030..0000000 --- a/spark/sql/session/sparksession_test.go +++ /dev/null @@ -1,237 +0,0 @@ -package session - -import ( - "bytes" - "context" - "testing" - - "github.com/apache/arrow/go/v12/arrow" - "github.com/apache/arrow/go/v12/arrow/array" - "github.com/apache/arrow/go/v12/arrow/ipc" - "github.com/apache/arrow/go/v12/arrow/memory" - proto "github.com/apache/spark-connect-go/v35/internal/generated" - "github.com/apache/spark-connect-go/v35/spark/mocks" - "github.com/apache/spark-connect-go/v35/spark/sparkerrors" - "github.com/apache/spark-connect-go/v35/spark/sql" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/grpc" -) - -type connectServiceClient struct { - t *testing.T - - analysePlanResponse *proto.AnalyzePlanResponse - executePlanClient proto.SparkConnectService_ExecutePlanClient - expectedExecutePlanRequest *proto.ExecutePlanRequest - - err error -} - -func (c *connectServiceClient) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ExecutePlanClient, error) { - if c.expectedExecutePlanRequest != nil { - assert.Equal(c.t, c.expectedExecutePlanRequest, in) - } - return c.executePlanClient, c.err -} - -func (c *connectServiceClient) AnalyzePlan(ctx context.Context, in *proto.AnalyzePlanRequest, opts ...grpc.CallOption) (*proto.AnalyzePlanResponse, error) { - return c.analysePlanResponse, c.err -} - -func (c *connectServiceClient) Config(ctx context.Context, in *proto.ConfigRequest, opts ...grpc.CallOption) (*proto.ConfigResponse, error) { - return nil, c.err -} - -func (c *connectServiceClient) AddArtifacts(ctx context.Context, opts ...grpc.CallOption) (proto.SparkConnectService_AddArtifactsClient, error) { - return nil, c.err -} - -func (c *connectServiceClient) ArtifactStatus(ctx context.Context, in *proto.ArtifactStatusesRequest, opts ...grpc.CallOption) (*proto.ArtifactStatusesResponse, error) { - return nil, c.err -} - -func (c *connectServiceClient) Interrupt(ctx context.Context, in *proto.InterruptRequest, opts ...grpc.CallOption) (*proto.InterruptResponse, error) { - return nil, c.err -} - -func (c *connectServiceClient) ReattachExecute(ctx context.Context, in *proto.ReattachExecuteRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ReattachExecuteClient, error) { - return nil, c.err -} - -func (c *connectServiceClient) ReleaseExecute(ctx context.Context, in *proto.ReleaseExecuteRequest, opts ...grpc.CallOption) (*proto.ReleaseExecuteResponse, error) { - return nil, c.err -} - -func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { - ctx := context.Background() - reponse := &proto.AnalyzePlanResponse{} - session := &sparkSessionImpl{ - client: &connectServiceClient{ - analysePlanResponse: 
reponse, - }, - } - resp, err := session.AnalyzePlan(ctx, &proto.Plan{}) - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestAnalyzePlanFailsIfClientFails(t *testing.T) { - ctx := context.Background() - session := &sparkSessionImpl{ - client: &connectServiceClient{ - err: assert.AnError, - }, - } - resp, err := session.AnalyzePlan(ctx, &proto.Plan{}) - assert.Nil(t, resp) - assert.Error(t, err) -} - -func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { - ctx := context.Background() - - plan := &proto.Plan{} - request := &proto.ExecutePlanRequest{ - Plan: plan, - UserContext: &proto.UserContext{ - UserId: "na", - }, - } - session := &sparkSessionImpl{ - client: &connectServiceClient{ - executePlanClient: &sql.ExecutePlanClient{}, - expectedExecutePlanRequest: request, - t: t, - }, - } - resp, err := session.ExecutePlan(ctx, plan) - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestSQLCallsExecutePlanWithSQLOnClient(t *testing.T) { - ctx := context.Background() - - query := "select * from bla" - plan := &proto.Plan{ - OpType: &proto.Plan_Command{ - Command: &proto.Command{ - CommandType: &proto.Command_SqlCommand{ - SqlCommand: &proto.SqlCommand{ - Sql: query, - }, - }, - }, - }, - } - request := &proto.ExecutePlanRequest{ - Plan: plan, - UserContext: &proto.UserContext{ - UserId: "na", - }, - } - session := &sparkSessionImpl{ - client: &connectServiceClient{ - executePlanClient: &sql.ExecutePlanClient{&mocks.ProtoClient{ - RecvResponse: &proto.ExecutePlanResponse{ - ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ - SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, - }, - }, - }}, - expectedExecutePlanRequest: request, - t: t, - }, - } - resp, err := session.Sql(ctx, query) - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestNewSessionBuilderCreatesASession(t *testing.T) { - ctx := context.Background() - spark, err := NewSessionBuilder().Remote("sc://connection").Build(ctx) - assert.NoError(t, err) - assert.NotNil(t, spark) -} - -func TestNewSessionBuilderFailsIfConnectionStringIsInvalid(t *testing.T) { - ctx := context.Background() - spark, err := NewSessionBuilder().Remote("invalid").Build(ctx) - assert.Error(t, err) - assert.ErrorIs(t, err, sparkerrors.InvalidInputError) - assert.Nil(t, spark) -} - -func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { - ctx := context.Background() - - arrowFields := []arrow.Field{ - { - Name: "show_string", - Type: &arrow.StringType{}, - }, - } - arrowSchema := arrow.NewSchema(arrowFields, nil) - var buf bytes.Buffer - arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) - defer arrowWriter.Close() - - alloc := memory.NewGoAllocator() - recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) - defer recordBuilder.Release() - - recordBuilder.Field(0).(*array.StringBuilder).Append("str1a\nstr1b") - recordBuilder.Field(0).(*array.StringBuilder).Append("str2") - - record := recordBuilder.NewRecord() - defer record.Release() - - err := arrowWriter.Write(record) - require.Nil(t, err) - - query := "select * from bla" - - session := &sparkSessionImpl{ - client: &connectServiceClient{ - executePlanClient: &sql.ExecutePlanClient{ - &mocks.ProtoClient{ - RecvResponses: []*proto.ExecutePlanResponse{ - { - ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ - SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, - }, - }, - { - ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ - ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ - RowCount: 1, - 
Data: buf.Bytes(), - }, - }, - }, - }, - }, - }, - t: t, - }, - } - resp, err := session.Sql(ctx, query) - assert.NoError(t, err) - assert.NotNil(t, resp) - writer, err := resp.Repartition(1, []string{"1"}) - assert.NoError(t, err) - collector := &testCollector{} - err = writer.WriteResult(ctx, collector, 1, false) - assert.NoError(t, err) - assert.Equal(t, []any{"str2"}, collector.row) -} - -type testCollector struct { - row []any -} - -func (t *testCollector) WriteRow(values []any) { - t.row = values -} From b23ca71c4b4794766ec6cc5d0bf514bdbed0e217 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 8 Jul 2024 10:26:33 +0200 Subject: [PATCH 3/7] more refactoring and making tests tests --- .gitignore | 6 +- spark/client/client.go | 128 ++++++++++- spark/client/client_test.go | 72 +++--- spark/client/testutils/utils.go | 6 +- spark/mocks/mocks.go | 28 +-- spark/sparkerrors/errors.go | 17 ++ spark/sql/dataframe.go | 332 ++++------------------------ spark/sql/dataframe_test.go | 300 ------------------------- spark/sql/dataframereader.go | 2 +- spark/sql/dataframewriter_test.go | 26 ++- spark/sql/mocks_test.go | 10 + spark/sql/row.go | 10 +- spark/sql/row_test.go | 6 +- spark/sql/sparksession.go | 30 ++- spark/sql/sparksession_test.go | 81 +++++-- spark/sql/types/arrow.go | 182 +++++++++++++++ spark/sql/types/arrow_test.go | 312 ++++++++++++++++++++++++++ spark/sql/types/conversion.go | 69 ++++++ spark/sql/{ => types}/datatype.go | 2 +- spark/sql/{ => types}/structtype.go | 2 +- 20 files changed, 928 insertions(+), 693 deletions(-) create mode 100644 spark/sql/types/arrow.go create mode 100644 spark/sql/types/arrow_test.go create mode 100644 spark/sql/types/conversion.go rename spark/sql/{ => types}/datatype.go (99%) rename spark/sql/{ => types}/structtype.go (98%) diff --git a/.gitignore b/.gitignore index e76d6f0..1ecb27b 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,8 @@ coverage* # Ignore binaries cmd/spark-connect-example-raw-grpc-client/spark-connect-example-raw-grpc-client -cmd/spark-connect-example-spark-session/spark-connect-example-spark-session \ No newline at end of file +cmd/spark-connect-example-spark-session/spark-connect-example-spark-session + +target + +lib \ No newline at end of file diff --git a/spark/client/client.go b/spark/client/client.go index 270bf10..decd8ba 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -21,6 +21,10 @@ import ( "fmt" "io" + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/spark-connect-go/v35/spark/sql/types" + "github.com/apache/spark-connect-go/v35/internal/generated" proto "github.com/apache/spark-connect-go/v35/internal/generated" "github.com/apache/spark-connect-go/v35/spark/sparkerrors" @@ -34,6 +38,7 @@ import ( // RPC API level and the necessary translation of Arrow to Row objects. 
type SparkExecutor interface { ExecutePlan(ctx context.Context, plan *generated.Plan) (*ExecutePlanClient, error) + ExecuteCommand(ctx context.Context, plan *generated.Plan) (arrow.Table, *types.StructType, map[string]any, error) AnalyzePlan(ctx context.Context, plan *generated.Plan) (*generated.AnalyzePlanResponse, error) } @@ -43,6 +48,30 @@ type SparkExecutorImpl struct { sessionId string } +func (s *SparkExecutorImpl) ExecuteCommand(ctx context.Context, plan *proto.Plan) (arrow.Table, *types.StructType, map[string]any, error) { + request := proto.ExecutePlanRequest{ + SessionId: s.sessionId, + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + + // Append the other items to the request. + ctx = metadata.NewOutgoingContext(ctx, s.metadata) + c, err := s.client.ExecutePlan(ctx, &request) + if err != nil { + return nil, nil, nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) + } + + respHandler := NewExecutePlanClient(c, s.sessionId) + schema, table, err := respHandler.ToTable() + if err != nil { + return nil, nil, nil, err + } + return table, schema, respHandler.properties, nil +} + func (s *SparkExecutorImpl) ExecutePlan(ctx context.Context, plan *proto.Plan) (*ExecutePlanClient, error) { request := proto.ExecutePlanRequest{ SessionId: s.sessionId, @@ -58,7 +87,7 @@ func (s *SparkExecutorImpl) ExecutePlan(ctx context.Context, plan *proto.Plan) ( if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) } - return NewExecutePlanClient(c), nil + return NewExecutePlanClient(c, s.sessionId), nil } func (s *SparkExecutorImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) { @@ -100,13 +129,102 @@ func NewSparkExecutorFromClient(client proto.SparkConnectServiceClient, md metad } } +// ExecutePlanClient is the wrapper around the result of the execution of a query plan using +// Spark Connect. type ExecutePlanClient struct { - generated.SparkConnectService_ExecutePlanClient + // The GRPC stream to read the response messages. + responseStream generated.SparkConnectService_ExecutePlanClient + // The schema of the result of the operation. + schema *types.StructType + // The sessionId is ised to verify the server side session. + sessionId string + done bool + properties map[string]any +} + +// In PySpark we have a generic toTable method that fetches all of the +// data and converts it to the desired format. +func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { + var recordBatches []arrow.Record + var arrowSchema *arrow.Schema + recordBatches = make([]arrow.Record, 0) + + for { + resp, err := c.responseStream.Recv() + if err == io.EOF { + break + } + if err != nil { + return nil, nil, sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError) + } + + // Process the message + + // Check that the server returned the session ID that we were expecting + // and that it has not changed. + if resp.GetSessionId() != c.sessionId { + return c.schema, nil, sparkerrors.InvalidServerSideSessionError{ + OwnSessionId: c.sessionId, + ReceivedSessionId: resp.GetSessionId(), + } + } + + // Check if the response has already the schema set and if yes, convert + // the proto DataType to a StructType. 
+ if resp.Schema != nil { + c.schema, err = types.ConvertProtoDataTypeToStructType(resp.Schema) + if err != nil { + return nil, nil, err + } + } + + switch x := resp.ResponseType.(type) { + case *proto.ExecutePlanResponse_SqlCommandResult_: + if val := x.SqlCommandResult.GetRelation(); val != nil { + c.properties["sql_command_result"] = val + } + case *proto.ExecutePlanResponse_ArrowBatch_: + // Do nothing. + record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema) + if err != nil { + return nil, nil, err + } + arrowSchema = record.Schema() + record.Retain() + recordBatches = append(recordBatches, record) + case *proto.ExecutePlanResponse_ResultComplete_: + c.done = true + default: + fmt.Printf("Received unsupported response ") + //return nil, nil, &sparkerrors.UnsupportedResponseTypeError{ + // ResponseType: x, + //} + } + } + + // Check that the result is logically complete. The result might not be complete + // because after 2 minutes the server will interrupt the connection and we have to + // send a ReAttach execute request. + //if !c.done { + // return nil, nil, sparkerrors.WithType(fmt.Errorf("the result is not complete"), sparkerrors.ExecutionError) + //} + // Return the schema and table. + if arrowSchema == nil { + return c.schema, nil, nil + } else { + return c.schema, array.NewTableFromRecords(arrowSchema, recordBatches), nil + } } -func NewExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) *ExecutePlanClient { +func NewExecutePlanClient( + responseClient proto.SparkConnectService_ExecutePlanClient, + sessionId string, +) *ExecutePlanClient { return &ExecutePlanClient{ - responseClient, + responseStream: responseClient, + sessionId: sessionId, + done: false, + properties: make(map[string]any), } } @@ -116,7 +234,7 @@ func NewExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanCl // If we do not consume the returned GRPC stream, Spark Connect Driver will not really save data frame. 
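One way a caller inside this repository (for example the session's `Sql` method) might consume the new `ExecuteCommand` entry point together with the `properties` map that `ToTable` fills in; the helper name and error message below are illustrative only:

```go
package example

import (
	"context"
	"fmt"

	proto "github.com/apache/spark-connect-go/v35/internal/generated"
	"github.com/apache/spark-connect-go/v35/spark/client"
)

// sqlRelation executes a command plan and extracts the relation that ToTable stores
// under the "sql_command_result" property when the server sends a SqlCommandResult.
func sqlRelation(ctx context.Context, exec client.SparkExecutor, plan *proto.Plan) (*proto.Relation, error) {
	_, _, props, err := exec.ExecuteCommand(ctx, plan)
	if err != nil {
		return nil, err
	}
	rel, ok := props["sql_command_result"].(*proto.Relation)
	if !ok {
		return nil, fmt.Errorf("response did not contain a sql command result")
	}
	return rel, nil
}
```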
func (c *ExecutePlanClient) ConsumeAll() error { for { - _, err := c.Recv() + _, err := c.responseStream.Recv() if err != nil { if errors.Is(err, io.EOF) { return nil diff --git a/spark/client/client_test.go b/spark/client/client_test.go index e601db8..060aa80 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -15,44 +15,34 @@ package client_test -import ( - "context" - "testing" - - proto "github.com/apache/spark-connect-go/v35/internal/generated" - "github.com/apache/spark-connect-go/v35/spark/client" - "github.com/apache/spark-connect-go/v35/spark/client/testutils" - "github.com/stretchr/testify/assert" -) - -func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { - ctx := context.Background() - response := &proto.AnalyzePlanResponse{} - c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, response, nil, nil), nil, "") - resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func TestAnalyzePlanFailsIfClientFails(t *testing.T) { - ctx := context.Background() - c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, nil, assert.AnError, nil), nil, "") - resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) - assert.Nil(t, resp) - assert.Error(t, err) -} - -func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { - ctx := context.Background() - plan := &proto.Plan{} - request := &proto.ExecutePlanRequest{ - Plan: plan, - UserContext: &proto.UserContext{ - UserId: "na", - }, - } - c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(request, &client.ExecutePlanClient{}, nil, nil, t), nil, "") - resp, err := c.ExecutePlan(ctx, plan) - assert.NoError(t, err) - assert.NotNil(t, resp) -} +//func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { +// ctx := context.Background() +// response := &proto.AnalyzePlanResponse{} +// c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, response, nil, nil), nil, "") +// resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) +// assert.NoError(t, err) +// assert.NotNil(t, resp) +//} +// +//func TestAnalyzePlanFailsIfClientFails(t *testing.T) { +// ctx := context.Background() +// c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, nil, assert.AnError, nil), nil, "") +// resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) +// assert.Nil(t, resp) +// assert.Error(t, err) +//} +// +//func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { +// ctx := context.Background() +// plan := &proto.Plan{} +// request := &proto.ExecutePlanRequest{ +// Plan: plan, +// UserContext: &proto.UserContext{ +// UserId: "na", +// }, +// } +// c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(request, &client.ExecutePlanClient{}, nil, nil, t), nil, "") +// resp, err := c.ExecutePlan(ctx, plan) +// assert.NoError(t, err) +// assert.NotNil(t, resp) +//} diff --git a/spark/client/testutils/utils.go b/spark/client/testutils/utils.go index a0313d1..746bfdd 100644 --- a/spark/client/testutils/utils.go +++ b/spark/client/testutils/utils.go @@ -20,16 +20,16 @@ import ( "testing" proto "github.com/apache/spark-connect-go/v35/internal/generated" - "github.com/apache/spark-connect-go/v35/spark/client" "github.com/stretchr/testify/assert" "google.golang.org/grpc" ) +// connectServiceClient is a mock implementation of the SparkConnectServiceClient interface. 
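The commented-out client tests can be re-expressed against the updated mock signature roughly as follows; this sketch mirrors the earlier `TestAnalyzePlanCallsAnalyzePlanOnClient` and is not part of this patch:

```go
package client_test

import (
	"context"
	"testing"

	proto "github.com/apache/spark-connect-go/v35/internal/generated"
	"github.com/apache/spark-connect-go/v35/spark/client"
	"github.com/apache/spark-connect-go/v35/spark/client/testutils"
	"github.com/stretchr/testify/assert"
)

func TestAnalyzePlanWithMockedService(t *testing.T) {
	ctx := context.Background()
	// The mock now accepts a plain SparkConnectService_ExecutePlanClient instead of the
	// wrapper type, so the executor can be driven entirely through the generated interfaces.
	mock := testutils.NewConnectServiceClientMock(nil, nil, &proto.AnalyzePlanResponse{}, nil, t)
	exec := client.NewSparkExecutorFromClient(mock, nil, "")

	resp, err := exec.AnalyzePlan(ctx, &proto.Plan{})
	assert.NoError(t, err)
	assert.NotNil(t, resp)
}
```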
type connectServiceClient struct { t *testing.T analysePlanResponse *proto.AnalyzePlanResponse - executePlanClient *client.ExecutePlanClient + executePlanClient proto.SparkConnectService_ExecutePlanClient expectedExecutePlanRequest *proto.ExecutePlanRequest err error @@ -70,7 +70,7 @@ func (c *connectServiceClient) ReleaseExecute(ctx context.Context, in *proto.Rel return nil, c.err } -func NewConnectServiceClientMock(epr *proto.ExecutePlanRequest, epc *client.ExecutePlanClient, apc *proto.AnalyzePlanResponse, err error, t *testing.T) proto.SparkConnectServiceClient { +func NewConnectServiceClientMock(epr *proto.ExecutePlanRequest, epc proto.SparkConnectService_ExecutePlanClient, apc *proto.AnalyzePlanResponse, err error, t *testing.T) proto.SparkConnectServiceClient { return &connectServiceClient{ t: t, expectedExecutePlanRequest: epr, diff --git a/spark/mocks/mocks.go b/spark/mocks/mocks.go index 90662aa..99c2d38 100644 --- a/spark/mocks/mocks.go +++ b/spark/mocks/mocks.go @@ -7,23 +7,25 @@ import ( "google.golang.org/grpc/metadata" ) -type ProtoClient struct { - RecvResponse *proto.ExecutePlanResponse - RecvResponses []*proto.ExecutePlanResponse +type MockResponse struct { + Resp *proto.ExecutePlanResponse + Err error +} - Err error +type ProtoClient struct { + // The stream of responses to return. + RecvResponse []*MockResponse + sent int } func (p *ProtoClient) Recv() (*proto.ExecutePlanResponse, error) { - if len(p.RecvResponses) != 0 { - p.RecvResponse = p.RecvResponses[0] - p.RecvResponses = p.RecvResponses[1:] - } - return p.RecvResponse, p.Err + val := p.RecvResponse[p.sent] + p.sent += 1 + return val.Resp, val.Err } func (p *ProtoClient) Header() (metadata.MD, error) { - return nil, p.Err + return nil, p.RecvResponse[p.sent].Err } func (p *ProtoClient) Trailer() metadata.MD { @@ -31,7 +33,7 @@ func (p *ProtoClient) Trailer() metadata.MD { } func (p *ProtoClient) CloseSend() error { - return p.Err + return p.RecvResponse[p.sent].Err } func (p *ProtoClient) Context() context.Context { @@ -39,9 +41,9 @@ func (p *ProtoClient) Context() context.Context { } func (p *ProtoClient) SendMsg(m interface{}) error { - return p.Err + return p.RecvResponse[p.sent].Err } func (p *ProtoClient) RecvMsg(m interface{}) error { - return p.Err + return p.RecvResponse[p.sent].Err } diff --git a/spark/sparkerrors/errors.go b/spark/sparkerrors/errors.go index 770cdef..d537fc0 100644 --- a/spark/sparkerrors/errors.go +++ b/spark/sparkerrors/errors.go @@ -47,3 +47,20 @@ var ( ExecutionError = errorType(errors.New("execution error")) InvalidInputError = errorType(errors.New("invalid input")) ) + +type UnsupportedResponseTypeError struct { + ResponseType interface{} +} + +func (e UnsupportedResponseTypeError) Error() string { + return fmt.Sprintf("Received unsupported response type: %T", e.ResponseType) +} + +type InvalidServerSideSessionError struct { + OwnSessionId string + ReceivedSessionId string +} + +func (e InvalidServerSideSessionError) Error() string { + return fmt.Sprintf("Received invalid session id %s, expected %s", e.ReceivedSessionId, e.OwnSessionId) +} diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index fb0a8af..47ba276 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -17,15 +17,11 @@ package sql import ( - "bytes" "context" - "errors" "fmt" - "io" - "github.com/apache/arrow/go/v12/arrow" - "github.com/apache/arrow/go/v12/arrow/array" - "github.com/apache/arrow/go/v12/arrow/ipc" + "github.com/apache/spark-connect-go/v35/spark/sql/types" + proto 
"github.com/apache/spark-connect-go/v35/internal/generated" "github.com/apache/spark-connect-go/v35/spark/sparkerrors" ) @@ -43,7 +39,7 @@ type DataFrame interface { // Show uses WriteResult to write the data frames to the console output. Show(ctx context.Context, numRows int, truncate bool) error // Schema returns the schema for the current data frame. - Schema(ctx context.Context) (*StructType, error) + Schema(ctx context.Context) (*types.StructType, error) // Collect returns the data rows of the current data frame. Collect(ctx context.Context) ([]Row, error) // Writer returns a data frame writer, which could be used to save data frame to supported storage. @@ -52,7 +48,7 @@ type DataFrame interface { // Deprecated: Use Writer Write() DataFrameWriter // CreateTempView creates or replaces a temporary view. - CreateTempView(ctx context.Context, viewName string, replace bool, global bool) error + CreateTempView(ctx context.Context, viewName string, replace, global bool) error // Repartition re-partitions a data frame. Repartition(numPartitions int, columns []string) (DataFrame, error) // RepartitionByRange re-partitions a data frame by range partition. @@ -78,8 +74,7 @@ func NewDataFrame(session *sparkSessionImpl, relation *proto.Relation) DataFrame } } -type consoleCollector struct { -} +type consoleCollector struct{} func (c consoleCollector) WriteRow(values []any) { fmt.Println(values...) @@ -119,31 +114,42 @@ func (df *dataFrameImpl) WriteResult(ctx context.Context, collector ResultCollec return sparkerrors.WithType(fmt.Errorf("failed to show dataframe: %w", err), sparkerrors.ExecutionError) } - for { - response, err := responseClient.Recv() - if err != nil { - return sparkerrors.WithType(fmt.Errorf("failed to receive show response: %w", err), sparkerrors.ReadError) - } - arrowBatch := response.GetArrowBatch() - if arrowBatch == nil { - continue - } - err = showArrowBatch(arrowBatch, collector) + schema, table, err := responseClient.ToTable() + if err != nil { + return err + } + + var rows []Row + rows = make([]Row, table.NumRows()) + + values, err := types.ReadArrowTable(table) + if err != nil { + return err + } + + for idx, v := range values { + row := NewRowWithSchema(v, schema) + rows[idx] = row + } + + for _, row := range rows { + values, err := row.Values() if err != nil { - return err + return sparkerrors.WithType(fmt.Errorf("failed to get values in the row: %w", err), sparkerrors.ReadError) } - return nil + collector.WriteRow(values) } + return nil } -func (df *dataFrameImpl) Schema(ctx context.Context) (*StructType, error) { +func (df *dataFrameImpl) Schema(ctx context.Context) (*types.StructType, error) { response, err := df.session.client.AnalyzePlan(ctx, df.createPlan()) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to analyze plan: %w", err), sparkerrors.ExecutionError) } responseSchema := response.GetSchema().Schema - return convertProtoDataTypeToStructType(responseSchema) + return types.ConvertProtoDataTypeToStructType(responseSchema) } func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) { @@ -152,43 +158,25 @@ func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) { return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError) } - var schema *StructType - var allRows []Row - - for { - response, err := responseClient.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - return allRows, nil - } else { - return nil, sparkerrors.WithType(fmt.Errorf("failed to receive plan 
execution response: %w", err), sparkerrors.ReadError) - } - } - - dataType := response.GetSchema() - if dataType != nil { - schema, err = convertProtoDataTypeToStructType(dataType) - if err != nil { - return nil, err - } - continue - } + var schema *types.StructType + schema, table, err := responseClient.ToTable() + if err != nil { + return nil, err + } - arrowBatch := response.GetArrowBatch() - if arrowBatch == nil { - continue - } + var rows []Row + rows = make([]Row, table.NumRows()) - rowBatch, err := readArrowBatchData(arrowBatch.Data, schema) - if err != nil { - return nil, err - } + values, err := types.ReadArrowTable(table) + if err != nil { + return nil, err + } - if allRows == nil { - allRows = make([]Row, 0, len(rowBatch)) - } - allRows = append(allRows, rowBatch...) + for idx, v := range values { + row := NewRowWithSchema(v, schema) + rows[idx] = row } + return rows, nil } func (df *dataFrameImpl) Write() DataFrameWriter { @@ -199,7 +187,7 @@ func (df *dataFrameImpl) Writer() DataFrameWriter { return newDataFrameWriter(df.session, df.relation) } -func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, replace bool, global bool) error { +func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, replace, global bool) error { plan := &proto.Plan{ OpType: &proto.Plan_Command{ Command: &proto.Command{ @@ -220,7 +208,8 @@ func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, re return sparkerrors.WithType(fmt.Errorf("failed to create temp view %s: %w", viewName, err), sparkerrors.ExecutionError) } - return responseClient.ConsumeAll() + _, _, err = responseClient.ToTable() + return err } func (df *dataFrameImpl) Repartition(numPartitions int, columns []string) (DataFrame, error) { @@ -305,228 +294,3 @@ func (df *dataFrameImpl) repartitionByExpressions(numPartitions int, partitionEx } return NewDataFrame(df.session, newRelation), nil } - -func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch, collector ResultCollector) error { - return showArrowBatchData(arrowBatch.Data, collector) -} - -func showArrowBatchData(data []byte, collector ResultCollector) error { - rows, err := readArrowBatchData(data, nil) - if err != nil { - return err - } - for _, row := range rows { - values, err := row.Values() - if err != nil { - return sparkerrors.WithType(fmt.Errorf("failed to get values in the row: %w", err), sparkerrors.ReadError) - } - collector.WriteRow(values) - } - return nil -} - -func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) { - reader := bytes.NewReader(data) - arrowReader, err := ipc.NewReader(reader) - if err != nil { - return nil, sparkerrors.WithType(fmt.Errorf("failed to create arrow reader: %w", err), sparkerrors.ReadError) - } - defer arrowReader.Release() - - var rows []Row - - for { - record, err := arrowReader.Read() - if err != nil { - if errors.Is(err, io.EOF) { - return rows, nil - } else { - return nil, sparkerrors.WithType(fmt.Errorf("failed to read arrow: %w", err), sparkerrors.ReadError) - } - } - - values, err := readArrowRecord(record) - if err != nil { - return nil, err - } - - numRows := int(record.NumRows()) - if rows == nil { - rows = make([]Row, 0, numRows) - } - - for _, v := range values { - row := NewRowWithSchema(v, schema) - rows = append(rows, row) - } - - hasNext := arrowReader.Next() - if !hasNext { - break - } - } - - return rows, nil -} - -// readArrowRecordColumn reads all values from arrow record and return [][]any -func readArrowRecord(record 
arrow.Record) ([][]any, error) { - numRows := record.NumRows() - numColumns := int(record.NumCols()) - - values := make([][]any, numRows) - for i := range values { - values[i] = make([]any, numColumns) - } - - for columnIndex := 0; columnIndex < numColumns; columnIndex++ { - err := readArrowRecordColumn(record, columnIndex, values) - if err != nil { - return nil, err - } - } - return values, nil -} - -// readArrowRecordColumn reads all values in a column and stores them in values -func readArrowRecordColumn(record arrow.Record, columnIndex int, values [][]any) error { - numRows := int(record.NumRows()) - columnData := record.Column(columnIndex).Data() - dataTypeId := columnData.DataType().ID() - switch dataTypeId { - case arrow.BOOL: - vector := array.NewBooleanData(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.INT8: - vector := array.NewInt8Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.INT16: - vector := array.NewInt16Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.INT32: - vector := array.NewInt32Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.INT64: - vector := array.NewInt64Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.FLOAT16: - vector := array.NewFloat16Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.FLOAT32: - vector := array.NewFloat32Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.FLOAT64: - vector := array.NewFloat64Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.DECIMAL | arrow.DECIMAL128: - vector := array.NewDecimal128Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.DECIMAL256: - vector := array.NewDecimal256Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.STRING: - vector := array.NewStringData(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.BINARY: - vector := array.NewBinaryData(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.TIMESTAMP: - vector := array.NewTimestampData(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - case arrow.DATE64: - vector := array.NewDate64Data(columnData) - for rowIndex := 0; rowIndex < numRows; rowIndex++ { - values[rowIndex][columnIndex] = vector.Value(rowIndex) - } - default: - return fmt.Errorf("unsupported arrow data type %s in column %d", dataTypeId.String(), columnIndex) - } - return nil -} - -func convertProtoDataTypeToStructType(input *proto.DataType) (*StructType, error) { - dataTypeStruct := input.GetStruct() - if dataTypeStruct 
== nil { - return nil, sparkerrors.WithType(errors.New("dataType.GetStruct() is nil"), sparkerrors.InvalidInputError) - } - return &StructType{ - Fields: convertProtoStructFields(dataTypeStruct.Fields), - }, nil -} - -func convertProtoStructFields(input []*proto.DataType_StructField) []StructField { - result := make([]StructField, len(input)) - for i, f := range input { - result[i] = convertProtoStructField(f) - } - return result -} - -func convertProtoStructField(field *proto.DataType_StructField) StructField { - return StructField{ - Name: field.Name, - DataType: convertProtoDataTypeToDataType(field.DataType), - } -} - -// convertProtoDataTypeToDataType converts protobuf data type to Spark connect sql data type -func convertProtoDataTypeToDataType(input *proto.DataType) DataType { - switch v := input.GetKind().(type) { - case *proto.DataType_Boolean_: - return BooleanType{} - case *proto.DataType_Byte_: - return ByteType{} - case *proto.DataType_Short_: - return ShortType{} - case *proto.DataType_Integer_: - return IntegerType{} - case *proto.DataType_Long_: - return LongType{} - case *proto.DataType_Float_: - return FloatType{} - case *proto.DataType_Double_: - return DoubleType{} - case *proto.DataType_Decimal_: - return DecimalType{} - case *proto.DataType_String_: - return StringType{} - case *proto.DataType_Binary_: - return BinaryType{} - case *proto.DataType_Timestamp_: - return TimestampType{} - case *proto.DataType_TimestampNtz: - return TimestampNtzType{} - case *proto.DataType_Date_: - return DateType{} - default: - return UnsupportedType{ - TypeInfo: v, - } - } -} diff --git a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go index cccd703..831ad94 100644 --- a/spark/sql/dataframe_test.go +++ b/spark/sql/dataframe_test.go @@ -15,303 +15,3 @@ // limitations under the License. 
package sql - -import ( - "bytes" - "testing" - - "github.com/apache/arrow/go/v12/arrow" - "github.com/apache/arrow/go/v12/arrow/array" - "github.com/apache/arrow/go/v12/arrow/decimal128" - "github.com/apache/arrow/go/v12/arrow/decimal256" - "github.com/apache/arrow/go/v12/arrow/float16" - "github.com/apache/arrow/go/v12/arrow/ipc" - "github.com/apache/arrow/go/v12/arrow/memory" - proto "github.com/apache/spark-connect-go/v35/internal/generated" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestShowArrowBatchData(t *testing.T) { - arrowFields := []arrow.Field{ - { - Name: "show_string", - Type: &arrow.StringType{}, - }, - } - arrowSchema := arrow.NewSchema(arrowFields, nil) - var buf bytes.Buffer - arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) - defer arrowWriter.Close() - - alloc := memory.NewGoAllocator() - recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) - defer recordBuilder.Release() - - recordBuilder.Field(0).(*array.StringBuilder).Append("str1a\nstr1b") - recordBuilder.Field(0).(*array.StringBuilder).Append("str2") - - record := recordBuilder.NewRecord() - defer record.Release() - - err := arrowWriter.Write(record) - require.Nil(t, err) - - collector := &testCollector{} - err = showArrowBatchData(buf.Bytes(), collector) - assert.Nil(t, err) - assert.Equal(t, []any{"str2"}, collector.row) -} - -func TestReadArrowRecord(t *testing.T) { - arrowFields := []arrow.Field{ - { - Name: "boolean_column", - Type: &arrow.BooleanType{}, - }, - { - Name: "int8_column", - Type: &arrow.Int8Type{}, - }, - { - Name: "int16_column", - Type: &arrow.Int16Type{}, - }, - { - Name: "int32_column", - Type: &arrow.Int32Type{}, - }, - { - Name: "int64_column", - Type: &arrow.Int64Type{}, - }, - { - Name: "float16_column", - Type: &arrow.Float16Type{}, - }, - { - Name: "float32_column", - Type: &arrow.Float32Type{}, - }, - { - Name: "float64_column", - Type: &arrow.Float64Type{}, - }, - { - Name: "decimal128_column", - Type: &arrow.Decimal128Type{}, - }, - { - Name: "decimal256_column", - Type: &arrow.Decimal256Type{}, - }, - { - Name: "string_column", - Type: &arrow.StringType{}, - }, - { - Name: "binary_column", - Type: &arrow.BinaryType{}, - }, - { - Name: "timestamp_column", - Type: &arrow.TimestampType{}, - }, - { - Name: "date64_column", - Type: &arrow.Date64Type{}, - }, - } - arrowSchema := arrow.NewSchema(arrowFields, nil) - var buf bytes.Buffer - arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) - defer arrowWriter.Close() - - alloc := memory.NewGoAllocator() - recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) - defer recordBuilder.Release() - - i := 0 - recordBuilder.Field(i).(*array.BooleanBuilder).Append(false) - recordBuilder.Field(i).(*array.BooleanBuilder).Append(true) - - i++ - recordBuilder.Field(i).(*array.Int8Builder).Append(1) - recordBuilder.Field(i).(*array.Int8Builder).Append(2) - - i++ - recordBuilder.Field(i).(*array.Int16Builder).Append(10) - recordBuilder.Field(i).(*array.Int16Builder).Append(20) - - i++ - recordBuilder.Field(i).(*array.Int32Builder).Append(100) - recordBuilder.Field(i).(*array.Int32Builder).Append(200) - - i++ - recordBuilder.Field(i).(*array.Int64Builder).Append(1000) - recordBuilder.Field(i).(*array.Int64Builder).Append(2000) - - i++ - recordBuilder.Field(i).(*array.Float16Builder).Append(float16.New(10000.1)) - recordBuilder.Field(i).(*array.Float16Builder).Append(float16.New(20000.1)) - - i++ - recordBuilder.Field(i).(*array.Float32Builder).Append(100000.1) - 
recordBuilder.Field(i).(*array.Float32Builder).Append(200000.1) - - i++ - recordBuilder.Field(i).(*array.Float64Builder).Append(1000000.1) - recordBuilder.Field(i).(*array.Float64Builder).Append(2000000.1) - - i++ - recordBuilder.Field(i).(*array.Decimal128Builder).Append(decimal128.FromI64(10000000)) - recordBuilder.Field(i).(*array.Decimal128Builder).Append(decimal128.FromI64(20000000)) - - i++ - recordBuilder.Field(i).(*array.Decimal256Builder).Append(decimal256.FromI64(100000000)) - recordBuilder.Field(i).(*array.Decimal256Builder).Append(decimal256.FromI64(200000000)) - - i++ - recordBuilder.Field(i).(*array.StringBuilder).Append("str1") - recordBuilder.Field(i).(*array.StringBuilder).Append("str2") - - i++ - recordBuilder.Field(i).(*array.BinaryBuilder).Append([]byte("bytes1")) - recordBuilder.Field(i).(*array.BinaryBuilder).Append([]byte("bytes2")) - - i++ - recordBuilder.Field(i).(*array.TimestampBuilder).Append(arrow.Timestamp(1686981953115000)) - recordBuilder.Field(i).(*array.TimestampBuilder).Append(arrow.Timestamp(1686981953116000)) - - i++ - recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953117000)) - recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953118000)) - - record := recordBuilder.NewRecord() - defer record.Release() - - values, err := readArrowRecord(record) - require.Nil(t, err) - assert.Equal(t, 2, len(values)) - assert.Equal(t, []any{ - false, int8(1), int16(10), int32(100), int64(1000), - float16.New(10000.1), float32(100000.1), 1000000.1, - decimal128.FromI64(10000000), decimal256.FromI64(100000000), - "str1", []byte("bytes1"), - arrow.Timestamp(1686981953115000), arrow.Date64(1686981953117000)}, - values[0]) - assert.Equal(t, []any{ - true, int8(2), int16(20), int32(200), int64(2000), - float16.New(20000.1), float32(200000.1), 2000000.1, - decimal128.FromI64(20000000), decimal256.FromI64(200000000), - "str2", []byte("bytes2"), - arrow.Timestamp(1686981953116000), arrow.Date64(1686981953118000)}, - values[1]) -} - -func TestReadArrowRecord_UnsupportedType(t *testing.T) { - arrowFields := []arrow.Field{ - { - Name: "unsupported_type_column", - Type: &arrow.MonthIntervalType{}, - }, - } - arrowSchema := arrow.NewSchema(arrowFields, nil) - var buf bytes.Buffer - arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) - defer arrowWriter.Close() - - alloc := memory.NewGoAllocator() - recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) - defer recordBuilder.Release() - - recordBuilder.Field(0).(*array.MonthIntervalBuilder).Append(1) - - record := recordBuilder.NewRecord() - defer record.Release() - - _, err := readArrowRecord(record) - require.NotNil(t, err) -} - -func TestConvertProtoDataTypeToDataType(t *testing.T) { - booleanDataType := &proto.DataType{ - Kind: &proto.DataType_Boolean_{}, - } - assert.Equal(t, "Boolean", convertProtoDataTypeToDataType(booleanDataType).TypeName()) - - byteDataType := &proto.DataType{ - Kind: &proto.DataType_Byte_{}, - } - assert.Equal(t, "Byte", convertProtoDataTypeToDataType(byteDataType).TypeName()) - - shortDataType := &proto.DataType{ - Kind: &proto.DataType_Short_{}, - } - assert.Equal(t, "Short", convertProtoDataTypeToDataType(shortDataType).TypeName()) - - integerDataType := &proto.DataType{ - Kind: &proto.DataType_Integer_{}, - } - assert.Equal(t, "Integer", convertProtoDataTypeToDataType(integerDataType).TypeName()) - - longDataType := &proto.DataType{ - Kind: &proto.DataType_Long_{}, - } - assert.Equal(t, "Long", 
convertProtoDataTypeToDataType(longDataType).TypeName()) - - floatDataType := &proto.DataType{ - Kind: &proto.DataType_Float_{}, - } - assert.Equal(t, "Float", convertProtoDataTypeToDataType(floatDataType).TypeName()) - - doubleDataType := &proto.DataType{ - Kind: &proto.DataType_Double_{}, - } - assert.Equal(t, "Double", convertProtoDataTypeToDataType(doubleDataType).TypeName()) - - decimalDataType := &proto.DataType{ - Kind: &proto.DataType_Decimal_{}, - } - assert.Equal(t, "Decimal", convertProtoDataTypeToDataType(decimalDataType).TypeName()) - - stringDataType := &proto.DataType{ - Kind: &proto.DataType_String_{}, - } - assert.Equal(t, "String", convertProtoDataTypeToDataType(stringDataType).TypeName()) - - binaryDataType := &proto.DataType{ - Kind: &proto.DataType_Binary_{}, - } - assert.Equal(t, "Binary", convertProtoDataTypeToDataType(binaryDataType).TypeName()) - - timestampDataType := &proto.DataType{ - Kind: &proto.DataType_Timestamp_{}, - } - assert.Equal(t, "Timestamp", convertProtoDataTypeToDataType(timestampDataType).TypeName()) - - timestampNtzDataType := &proto.DataType{ - Kind: &proto.DataType_TimestampNtz{}, - } - assert.Equal(t, "TimestampNtz", convertProtoDataTypeToDataType(timestampNtzDataType).TypeName()) - - dateDataType := &proto.DataType{ - Kind: &proto.DataType_Date_{}, - } - assert.Equal(t, "Date", convertProtoDataTypeToDataType(dateDataType).TypeName()) -} - -func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) { - unsupportedDataType := &proto.DataType{ - Kind: &proto.DataType_YearMonthInterval_{}, - } - assert.Equal(t, "Unsupported", convertProtoDataTypeToDataType(unsupportedDataType).TypeName()) -} - -type testCollector struct { - row []any -} - -func (t *testCollector) WriteRow(values []any) { - t.row = values -} diff --git a/spark/sql/dataframereader.go b/spark/sql/dataframereader.go index 732ec7c..6d1dff8 100644 --- a/spark/sql/dataframereader.go +++ b/spark/sql/dataframereader.go @@ -40,7 +40,7 @@ func (w *dataFrameReaderImpl) Load(path string) (DataFrame, error) { return NewDataFrame(w.sparkSession, toRelation(path, format)), nil } -func toRelation(path string, format string) *proto.Relation { +func toRelation(path, format string) *proto.Relation { return &proto.Relation{ RelType: &proto.Relation_Read{ Read: &proto.Read{ diff --git a/spark/sql/dataframewriter_test.go b/spark/sql/dataframewriter_test.go index 3c20df0..93dd54d 100644 --- a/spark/sql/dataframewriter_test.go +++ b/spark/sql/dataframewriter_test.go @@ -5,6 +5,8 @@ import ( "io" "testing" + "github.com/google/uuid" + "github.com/apache/spark-connect-go/v35/spark/client" proto "github.com/apache/spark-connect-go/v35/internal/generated" @@ -42,11 +44,16 @@ func TestSaveExecutesWriteOperationUntilEOF(t *testing.T) { relation := &proto.Relation{} executor := &testExecutor{ client: client.NewExecutePlanClient(&mocks.ProtoClient{ - Err: io.EOF, - }), + RecvResponse: []*mocks.MockResponse{ + { + Err: io.EOF, + }, + }, + }, uuid.NewString()), } session := &sparkSessionImpl{ - client: executor, + client: executor, + sessionId: uuid.NewString(), } ctx := context.Background() path := "path" @@ -62,11 +69,16 @@ func TestSaveFailsIfAnotherErrorHappensWhenReadingStream(t *testing.T) { relation := &proto.Relation{} executor := &testExecutor{ client: client.NewExecutePlanClient(&mocks.ProtoClient{ - Err: assert.AnError, - }), + RecvResponse: []*mocks.MockResponse{ + { + Err: assert.AnError, + }, + }, + }, uuid.NewString()), } session := &sparkSessionImpl{ - client: executor, + client: 
executor, + sessionId: uuid.NewString(), } ctx := context.Background() path := "path" @@ -81,7 +93,7 @@ func TestSaveFailsIfAnotherErrorHappensWhenReadingStream(t *testing.T) { func TestSaveFailsIfAnotherErrorHappensWhenExecuting(t *testing.T) { relation := &proto.Relation{} executor := &testExecutor{ - client: client.NewExecutePlanClient(&mocks.ProtoClient{}), + client: client.NewExecutePlanClient(&mocks.ProtoClient{}, uuid.NewString()), err: assert.AnError, } session := &sparkSessionImpl{ diff --git a/spark/sql/mocks_test.go b/spark/sql/mocks_test.go index ca80b3a..25bcffa 100644 --- a/spark/sql/mocks_test.go +++ b/spark/sql/mocks_test.go @@ -3,6 +3,9 @@ package sql import ( "context" + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/spark-connect-go/v35/spark/sql/types" + client2 "github.com/apache/spark-connect-go/v35/spark/client" proto "github.com/apache/spark-connect-go/v35/internal/generated" @@ -24,3 +27,10 @@ func (t *testExecutor) ExecutePlan(ctx context.Context, plan *proto.Plan) (*clie func (t *testExecutor) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) { return t.response, nil } + +func (t *testExecutor) ExecuteCommand(ctx context.Context, plan *proto.Plan) (arrow.Table, *types.StructType, map[string]interface{}, error) { + if t.err != nil { + return nil, nil, nil, t.err + } + return nil, nil, nil, nil +} diff --git a/spark/sql/row.go b/spark/sql/row.go index bea2ab7..4f821af 100644 --- a/spark/sql/row.go +++ b/spark/sql/row.go @@ -16,10 +16,12 @@ package sql +import "github.com/apache/spark-connect-go/v35/spark/sql/types" + // Row represents a row in a DataFrame. type Row interface { // Schema returns the schema of the row. - Schema() (*StructType, error) + Schema() (*types.StructType, error) // Values returns the values of the row. Values() ([]any, error) } @@ -27,17 +29,17 @@ type Row interface { // genericRowWithSchema represents a row in a DataFrame with schema. 
type genericRowWithSchema struct { values []any - schema *StructType + schema *types.StructType } -func NewRowWithSchema(values []any, schema *StructType) Row { +func NewRowWithSchema(values []any, schema *types.StructType) Row { return &genericRowWithSchema{ values: values, schema: schema, } } -func (r *genericRowWithSchema) Schema() (*StructType, error) { +func (r *genericRowWithSchema) Schema() (*types.StructType, error) { return r.schema, nil } diff --git a/spark/sql/row_test.go b/spark/sql/row_test.go index 7ae4f97..bc14aa7 100644 --- a/spark/sql/row_test.go +++ b/spark/sql/row_test.go @@ -3,12 +3,14 @@ package sql import ( "testing" + "github.com/apache/spark-connect-go/v35/spark/sql/types" + "github.com/stretchr/testify/assert" ) func TestSchema(t *testing.T) { values := []any{1} - schema := &StructType{} + schema := &types.StructType{} row := NewRowWithSchema(values, schema) schema2, err := row.Schema() assert.NoError(t, err) @@ -17,7 +19,7 @@ func TestSchema(t *testing.T) { func TestValues(t *testing.T) { values := []any{1} - schema := &StructType{} + schema := &types.StructType{} row := NewRowWithSchema(values, schema) values2, err := row.Values() assert.NoError(t, err) diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index d6daa7a..b830c6e 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -90,7 +90,12 @@ func (s *sparkSessionImpl) Read() DataFrameReader { return NewDataframeReader(s) } +// Sql executes a sql query and returns the result as a DataFrame func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (DataFrame, error) { + // Due to the nature of Spark, we have to first submit the SQL query immediately as a command + // to make sure that all side effects have been executed properly. If no side effects are present, + // then simply prepare this as a SQL relation. + plan := &proto.Plan{ OpType: &proto.Plan_Command{ Command: &proto.Command{ @@ -102,20 +107,25 @@ func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (DataFrame, er }, }, } - responseClient, err := s.client.ExecutePlan(ctx, plan) + // We need an execute command here. 
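+ // ExecuteCommand drains the whole response stream eagerly; if the server reports a
+ // SQL command result, it is surfaced through the returned properties map under the
+ // "sql_command_result" key and unpacked into a relation below.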
+ _, _, properties, err := s.client.ExecuteCommand(ctx, plan) if err != nil { return nil, sparkerrors.WithType(fmt.Errorf("failed to execute sql: %s: %w", query, err), sparkerrors.ExecutionError) } - for { - response, err := responseClient.Recv() - if err != nil { - return nil, sparkerrors.WithType(fmt.Errorf("failed to receive ExecutePlan response: %w", err), sparkerrors.ReadError) - } - sqlCommandResult := response.GetSqlCommandResult() - if sqlCommandResult == nil { - continue + + val, ok := properties["sql_command_result"] + if !ok { + plan := &proto.Relation{ + RelType: &proto.Relation_Sql{ + Sql: &proto.SQL{ + Query: query, + }, + }, } - return NewDataFrame(s, sqlCommandResult.GetRelation()), nil + return NewDataFrame(s, plan), nil + } else { + rel := val.(*proto.Relation) + return NewDataFrame(s, rel), nil } } diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go index e02d8bd..775dce6 100644 --- a/spark/sql/sparksession_test.go +++ b/spark/sql/sparksession_test.go @@ -19,6 +19,7 @@ package sql import ( "bytes" "context" + "io" "testing" "github.com/apache/arrow/go/v12/arrow" @@ -56,14 +57,31 @@ func TestSQLCallsExecutePlanWithSQLOnClient(t *testing.T) { }, } - s := testutils.NewConnectServiceClientMock(request, &client.ExecutePlanClient{ - SparkConnectService_ExecutePlanClient: &mocks.ProtoClient{ - RecvResponse: &proto.ExecutePlanResponse{ + // Create the responses: + responses := []*mocks.MockResponse{ + { + Resp: &proto.ExecutePlanResponse{ ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, }, }, + Err: nil, + }, + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ResultComplete_{ + ResultComplete: &proto.ExecutePlanResponse_ResultComplete{}, + }, + }, + Err: nil, + }, + { + Err: io.EOF, }, + } + + s := testutils.NewConnectServiceClientMock(request, &mocks.ProtoClient{ + RecvResponse: responses, }, nil, nil, t) c := client.NewSparkExecutorFromClient(s, nil, "") @@ -119,24 +137,46 @@ func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { query := "select * from bla" - s := testutils.NewConnectServiceClientMock(nil, &client.ExecutePlanClient{ - SparkConnectService_ExecutePlanClient: &mocks.ProtoClient{ - RecvResponses: []*proto.ExecutePlanResponse{ - { - ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ - SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, - }, + // Create the responses: + responses := []*mocks.MockResponse{ + // The first stream of response is necessary for the SQL command. 
+ { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, }, - { - ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ - ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ - RowCount: 1, - Data: buf.Bytes(), - }, + }, + Err: nil, + }, + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ResultComplete_{ + ResultComplete: &proto.ExecutePlanResponse_ResultComplete{}, + }, + }, + Err: nil, + }, + { + Err: io.EOF, + }, + // The second stream of responses is for the actual execution + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + RowCount: 2, + Data: buf.Bytes(), }, }, }, }, + { + Err: io.EOF, + }, + } + + s := testutils.NewConnectServiceClientMock(nil, &mocks.ProtoClient{ + RecvResponse: responses, }, nil, nil, t) c := client.NewSparkExecutorFromClient(s, nil, "") @@ -147,10 +187,11 @@ func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { resp, err := session.Sql(ctx, query) assert.NoError(t, err) assert.NotNil(t, resp) - writer, err := resp.Repartition(1, []string{"1"}) + df, err := resp.Repartition(1, []string{"1"}) + assert.NoError(t, err) + rows, err := df.Collect(ctx) assert.NoError(t, err) - collector := &testCollector{} - err = writer.WriteResult(ctx, collector, 1, false) + vals, err := rows[1].Values() assert.NoError(t, err) - assert.Equal(t, []any{"str2"}, collector.row) + assert.Equal(t, []any{"str2"}, vals) } diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go new file mode 100644 index 0000000..5d13b36 --- /dev/null +++ b/spark/sql/types/arrow.go @@ -0,0 +1,182 @@ +package types + +import ( + "bytes" + "fmt" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/ipc" + "github.com/apache/spark-connect-go/v35/spark/sparkerrors" +) + +func ReadArrowTable(table arrow.Table) ([][]any, error) { + numRows := table.NumRows() + numColumns := int(table.NumCols()) + + values := make([][]any, numRows) + for i := range values { + values[i] = make([]any, numColumns) + } + + for columnIndex := 0; columnIndex < numColumns; columnIndex++ { + err := ReadArrowRecordColumn(table, columnIndex, values) + if err != nil { + return nil, err + } + } + return values, nil +} + +// readArrowRecordColumn reads all values in a column and stores them in values +func ReadArrowRecordColumn(record arrow.Table, columnIndex int, values [][]any) error { + chunkedColumn := record.Column(columnIndex).Data() + dataTypeId := chunkedColumn.DataType().ID() + switch dataTypeId { + case arrow.BOOL: + rowIndex := 0 + for _, columnData := range chunkedColumn.Chunks() { + vector := array.NewBooleanData(columnData.Data()) + for i := 0; i < columnData.Len(); i++ { + values[rowIndex][columnIndex] = vector.Value(i) + rowIndex += 1 + } + } + case arrow.INT8: + rowIndex := 0 + for _, columnData := range chunkedColumn.Chunks() { + vector := array.NewInt8Data(columnData.Data()) + for i := 0; i < columnData.Len(); i++ { + values[rowIndex][columnIndex] = vector.Value(i) + rowIndex += 1 + } + } + case arrow.INT16: + rowIndex := 0 + for _, columnData := range chunkedColumn.Chunks() { + vector := array.NewInt16Data(columnData.Data()) + for i := 0; i < columnData.Len(); i++ { + values[rowIndex][columnIndex] = vector.Value(i) + rowIndex += 1 + } + } + case arrow.INT32: + rowIndex := 0 + for _, columnData 
:= range chunkedColumn.Chunks() {
+ vector := array.NewInt32Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.INT64:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewInt64Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.FLOAT16:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewFloat16Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.FLOAT32:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewFloat32Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.FLOAT64:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewFloat64Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.DECIMAL | arrow.DECIMAL128:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewDecimal128Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.DECIMAL256:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewDecimal256Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.STRING:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewStringData(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.BINARY:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewBinaryData(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.TIMESTAMP:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewTimestampData(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ case arrow.DATE64:
+ rowIndex := 0
+ for _, columnData := range chunkedColumn.Chunks() {
+ vector := array.NewDate64Data(columnData.Data())
+ for i := 0; i < columnData.Len(); i++ {
+ values[rowIndex][columnIndex] = vector.Value(i)
+ rowIndex += 1
+ }
+ }
+ default:
+ return fmt.Errorf("unsupported arrow data type %s in column %d", dataTypeId.String(), columnIndex)
+ }
+ return nil
+}
+
+func ReadArrowBatchToRecord(data []byte, schema *StructType) (arrow.Record, error) {
+ reader := bytes.NewReader(data)
+ arrowReader, err := ipc.NewReader(reader)
+ if err != nil {
+ return nil, sparkerrors.WithType(fmt.Errorf("failed to create arrow reader: %w", err), sparkerrors.ReadError)
+ }
+ defer arrowReader.Release()
+
+ record, err := arrowReader.Read()
+ if err != nil {
+ return nil, sparkerrors.WithType(fmt.Errorf("failed to read arrow record: %w", err), sparkerrors.ReadError)
+ }
+ // Retain the record so it stays valid after the deferred reader release.
+ record.Retain()
+ return record, nil
+}
diff --git a/spark/sql/types/arrow_test.go 
b/spark/sql/types/arrow_test.go new file mode 100644 index 0000000..81604da --- /dev/null +++ b/spark/sql/types/arrow_test.go @@ -0,0 +1,312 @@ +package types_test + +import ( + "bytes" + "testing" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/decimal128" + "github.com/apache/arrow/go/v12/arrow/decimal256" + "github.com/apache/arrow/go/v12/arrow/float16" + "github.com/apache/arrow/go/v12/arrow/ipc" + "github.com/apache/arrow/go/v12/arrow/memory" + proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/sql/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShowArrowBatchData(t *testing.T) { + arrowFields := []arrow.Field{ + { + Name: "show_string", + Type: &arrow.StringType{}, + }, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + recordBuilder.Field(0).(*array.StringBuilder).Append("str1a\nstr1b") + recordBuilder.Field(0).(*array.StringBuilder).Append("str2") + + record := recordBuilder.NewRecord() + defer record.Release() + + err := arrowWriter.Write(record) + require.Nil(t, err) + + // Convert the data + record, err = types.ReadArrowBatchToRecord(buf.Bytes(), nil) + require.NoError(t, err) + + table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record}) + values, err := types.ReadArrowTable(table) + require.Nil(t, err) + assert.Equal(t, 2, len(values)) + assert.Equal(t, []any{"str1a\nstr1b"}, values[0]) + assert.Equal(t, []any{"str2"}, values[1]) +} + +func TestReadArrowRecord(t *testing.T) { + arrowFields := []arrow.Field{ + { + Name: "boolean_column", + Type: &arrow.BooleanType{}, + }, + { + Name: "int8_column", + Type: &arrow.Int8Type{}, + }, + { + Name: "int16_column", + Type: &arrow.Int16Type{}, + }, + { + Name: "int32_column", + Type: &arrow.Int32Type{}, + }, + { + Name: "int64_column", + Type: &arrow.Int64Type{}, + }, + { + Name: "float16_column", + Type: &arrow.Float16Type{}, + }, + { + Name: "float32_column", + Type: &arrow.Float32Type{}, + }, + { + Name: "float64_column", + Type: &arrow.Float64Type{}, + }, + { + Name: "decimal128_column", + Type: &arrow.Decimal128Type{}, + }, + { + Name: "decimal256_column", + Type: &arrow.Decimal256Type{}, + }, + { + Name: "string_column", + Type: &arrow.StringType{}, + }, + { + Name: "binary_column", + Type: &arrow.BinaryType{}, + }, + { + Name: "timestamp_column", + Type: &arrow.TimestampType{}, + }, + { + Name: "date64_column", + Type: &arrow.Date64Type{}, + }, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + i := 0 + recordBuilder.Field(i).(*array.BooleanBuilder).Append(false) + recordBuilder.Field(i).(*array.BooleanBuilder).Append(true) + + i++ + recordBuilder.Field(i).(*array.Int8Builder).Append(1) + recordBuilder.Field(i).(*array.Int8Builder).Append(2) + + i++ + recordBuilder.Field(i).(*array.Int16Builder).Append(10) + recordBuilder.Field(i).(*array.Int16Builder).Append(20) + + i++ + 
recordBuilder.Field(i).(*array.Int32Builder).Append(100) + recordBuilder.Field(i).(*array.Int32Builder).Append(200) + + i++ + recordBuilder.Field(i).(*array.Int64Builder).Append(1000) + recordBuilder.Field(i).(*array.Int64Builder).Append(2000) + + i++ + recordBuilder.Field(i).(*array.Float16Builder).Append(float16.New(10000.1)) + recordBuilder.Field(i).(*array.Float16Builder).Append(float16.New(20000.1)) + + i++ + recordBuilder.Field(i).(*array.Float32Builder).Append(100000.1) + recordBuilder.Field(i).(*array.Float32Builder).Append(200000.1) + + i++ + recordBuilder.Field(i).(*array.Float64Builder).Append(1000000.1) + recordBuilder.Field(i).(*array.Float64Builder).Append(2000000.1) + + i++ + recordBuilder.Field(i).(*array.Decimal128Builder).Append(decimal128.FromI64(10000000)) + recordBuilder.Field(i).(*array.Decimal128Builder).Append(decimal128.FromI64(20000000)) + + i++ + recordBuilder.Field(i).(*array.Decimal256Builder).Append(decimal256.FromI64(100000000)) + recordBuilder.Field(i).(*array.Decimal256Builder).Append(decimal256.FromI64(200000000)) + + i++ + recordBuilder.Field(i).(*array.StringBuilder).Append("str1") + recordBuilder.Field(i).(*array.StringBuilder).Append("str2") + + i++ + recordBuilder.Field(i).(*array.BinaryBuilder).Append([]byte("bytes1")) + recordBuilder.Field(i).(*array.BinaryBuilder).Append([]byte("bytes2")) + + i++ + recordBuilder.Field(i).(*array.TimestampBuilder).Append(arrow.Timestamp(1686981953115000)) + recordBuilder.Field(i).(*array.TimestampBuilder).Append(arrow.Timestamp(1686981953116000)) + + i++ + recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953117000)) + recordBuilder.Field(i).(*array.Date64Builder).Append(arrow.Date64(1686981953118000)) + + record := recordBuilder.NewRecord() + defer record.Release() + + table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record}) + values, err := types.ReadArrowTable(table) + require.Nil(t, err) + assert.Equal(t, 2, len(values)) + assert.Equal(t, []any{ + false, int8(1), int16(10), int32(100), int64(1000), + float16.New(10000.1), float32(100000.1), 1000000.1, + decimal128.FromI64(10000000), decimal256.FromI64(100000000), + "str1", []byte("bytes1"), + arrow.Timestamp(1686981953115000), arrow.Date64(1686981953117000), + }, + values[0]) + assert.Equal(t, []any{ + true, int8(2), int16(20), int32(200), int64(2000), + float16.New(20000.1), float32(200000.1), 2000000.1, + decimal128.FromI64(20000000), decimal256.FromI64(200000000), + "str2", []byte("bytes2"), + arrow.Timestamp(1686981953116000), arrow.Date64(1686981953118000), + }, + values[1]) +} + +func TestReadArrowRecord_UnsupportedType(t *testing.T) { + arrowFields := []arrow.Field{ + { + Name: "unsupported_type_column", + Type: &arrow.MonthIntervalType{}, + }, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + recordBuilder.Field(0).(*array.MonthIntervalBuilder).Append(1) + + record := recordBuilder.NewRecord() + defer record.Release() + + table := array.NewTableFromRecords(arrowSchema, []arrow.Record{record}) + _, err := types.ReadArrowTable(table) + require.NotNil(t, err) +} + +func TestConvertProtoDataTypeToDataType(t *testing.T) { + booleanDataType := &proto.DataType{ + Kind: &proto.DataType_Boolean_{}, + } + assert.Equal(t, "Boolean", 
types.ConvertProtoDataTypeToDataType(booleanDataType).TypeName()) + + byteDataType := &proto.DataType{ + Kind: &proto.DataType_Byte_{}, + } + assert.Equal(t, "Byte", types.ConvertProtoDataTypeToDataType(byteDataType).TypeName()) + + shortDataType := &proto.DataType{ + Kind: &proto.DataType_Short_{}, + } + assert.Equal(t, "Short", types.ConvertProtoDataTypeToDataType(shortDataType).TypeName()) + + integerDataType := &proto.DataType{ + Kind: &proto.DataType_Integer_{}, + } + assert.Equal(t, "Integer", types.ConvertProtoDataTypeToDataType(integerDataType).TypeName()) + + longDataType := &proto.DataType{ + Kind: &proto.DataType_Long_{}, + } + assert.Equal(t, "Long", types.ConvertProtoDataTypeToDataType(longDataType).TypeName()) + + floatDataType := &proto.DataType{ + Kind: &proto.DataType_Float_{}, + } + assert.Equal(t, "Float", types.ConvertProtoDataTypeToDataType(floatDataType).TypeName()) + + doubleDataType := &proto.DataType{ + Kind: &proto.DataType_Double_{}, + } + assert.Equal(t, "Double", types.ConvertProtoDataTypeToDataType(doubleDataType).TypeName()) + + decimalDataType := &proto.DataType{ + Kind: &proto.DataType_Decimal_{}, + } + assert.Equal(t, "Decimal", types.ConvertProtoDataTypeToDataType(decimalDataType).TypeName()) + + stringDataType := &proto.DataType{ + Kind: &proto.DataType_String_{}, + } + assert.Equal(t, "String", types.ConvertProtoDataTypeToDataType(stringDataType).TypeName()) + + binaryDataType := &proto.DataType{ + Kind: &proto.DataType_Binary_{}, + } + assert.Equal(t, "Binary", types.ConvertProtoDataTypeToDataType(binaryDataType).TypeName()) + + timestampDataType := &proto.DataType{ + Kind: &proto.DataType_Timestamp_{}, + } + assert.Equal(t, "Timestamp", types.ConvertProtoDataTypeToDataType(timestampDataType).TypeName()) + + timestampNtzDataType := &proto.DataType{ + Kind: &proto.DataType_TimestampNtz{}, + } + assert.Equal(t, "TimestampNtz", types.ConvertProtoDataTypeToDataType(timestampNtzDataType).TypeName()) + + dateDataType := &proto.DataType{ + Kind: &proto.DataType_Date_{}, + } + assert.Equal(t, "Date", types.ConvertProtoDataTypeToDataType(dateDataType).TypeName()) +} + +func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) { + unsupportedDataType := &proto.DataType{ + Kind: &proto.DataType_YearMonthInterval_{}, + } + assert.Equal(t, "Unsupported", types.ConvertProtoDataTypeToDataType(unsupportedDataType).TypeName()) +} + +type testCollector struct { + row []any +} + +func (t *testCollector) WriteRow(values []any) { + t.row = values +} diff --git a/spark/sql/types/conversion.go b/spark/sql/types/conversion.go new file mode 100644 index 0000000..35c372d --- /dev/null +++ b/spark/sql/types/conversion.go @@ -0,0 +1,69 @@ +package types + +import ( + "errors" + + "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/sparkerrors" +) + +func ConvertProtoDataTypeToStructType(input *generated.DataType) (*StructType, error) { + dataTypeStruct := input.GetStruct() + if dataTypeStruct == nil { + return nil, sparkerrors.WithType(errors.New("dataType.GetStruct() is nil"), sparkerrors.InvalidInputError) + } + return &StructType{ + Fields: ConvertProtoStructFields(dataTypeStruct.Fields), + }, nil +} + +func ConvertProtoStructFields(input []*generated.DataType_StructField) []StructField { + result := make([]StructField, len(input)) + for i, f := range input { + result[i] = ConvertProtoStructField(f) + } + return result +} + +func ConvertProtoStructField(field *generated.DataType_StructField) StructField { + 
return StructField{ + Name: field.Name, + DataType: ConvertProtoDataTypeToDataType(field.DataType), + } +} + +// ConvertProtoDataTypeToDataType converts protobuf data type to Spark connect sql data type +func ConvertProtoDataTypeToDataType(input *generated.DataType) DataType { + switch v := input.GetKind().(type) { + case *generated.DataType_Boolean_: + return BooleanType{} + case *generated.DataType_Byte_: + return ByteType{} + case *generated.DataType_Short_: + return ShortType{} + case *generated.DataType_Integer_: + return IntegerType{} + case *generated.DataType_Long_: + return LongType{} + case *generated.DataType_Float_: + return FloatType{} + case *generated.DataType_Double_: + return DoubleType{} + case *generated.DataType_Decimal_: + return DecimalType{} + case *generated.DataType_String_: + return StringType{} + case *generated.DataType_Binary_: + return BinaryType{} + case *generated.DataType_Timestamp_: + return TimestampType{} + case *generated.DataType_TimestampNtz: + return TimestampNtzType{} + case *generated.DataType_Date_: + return DateType{} + default: + return UnsupportedType{ + TypeInfo: v, + } + } +} diff --git a/spark/sql/datatype.go b/spark/sql/types/datatype.go similarity index 99% rename from spark/sql/datatype.go rename to spark/sql/types/datatype.go index 398aa2c..46c3e57 100644 --- a/spark/sql/datatype.go +++ b/spark/sql/types/datatype.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sql +package types import ( "fmt" diff --git a/spark/sql/structtype.go b/spark/sql/types/structtype.go similarity index 98% rename from spark/sql/structtype.go rename to spark/sql/types/structtype.go index fd75236..de52f28 100644 --- a/spark/sql/structtype.go +++ b/spark/sql/types/structtype.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sql +package types // StructField represents a field in a StructType. type StructField struct { From 657dbe720505a113a31946cabc9637d5cd7e4b67 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 9 Jul 2024 07:46:43 +0200 Subject: [PATCH 4/7] adding more tests --- Makefile | 2 +- spark/sql/types/conversion_test.go | 61 ++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 spark/sql/types/conversion_test.go diff --git a/Makefile b/Makefile index 01fd6fb..ffbc2c5 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ # FIRST_GOPATH := $(firstword $(subst :, ,$(GOPATH))) -PKGS := $(shell go list ./... | grep -v /tests | grep -v /xcpb | grep -v /gpb) +PKGS := $(shell go list ./... | grep -v /tests | grep -v /xcpb | grep -v /gpb | grep -v /generated) GOFILES_NOVENDOR := $(shell find . -name vendor -prune -o -type f -name '*.go' -not -name '*.pb.go' -print) GOFILES_BUILD := $(shell find . -type f -name '*.go' -not -name '*_test.go') PROTOFILES := $(shell find . 
-name vendor -prune -o -type f -name '*.proto' -print) diff --git a/spark/sql/types/conversion_test.go b/spark/sql/types/conversion_test.go new file mode 100644 index 0000000..70260b9 --- /dev/null +++ b/spark/sql/types/conversion_test.go @@ -0,0 +1,61 @@ +package types_test + +import ( + "testing" + + proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/sql/types" + "github.com/stretchr/testify/assert" +) + +func TestConvertProtoStructFieldSupported(t *testing.T) { + protoType := &proto.DataType{Kind: &proto.DataType_Integer_{}} + structField := &proto.DataType_StructField{ + Name: "test", + DataType: protoType, + Nullable: true, + } + + dt := types.ConvertProtoStructField(structField) + assert.Equal(t, "test", dt.Name) + assert.IsType(t, types.IntegerType{}, dt.DataType) +} + +func TestConvertProtoStructFieldUnsupported(t *testing.T) { + protoType := &proto.DataType{Kind: &proto.DataType_CalendarInterval_{}} + structField := &proto.DataType_StructField{ + Name: "test", + DataType: protoType, + Nullable: true, + } + + dt := types.ConvertProtoStructField(structField) + assert.Equal(t, "test", dt.Name) + assert.IsType(t, types.UnsupportedType{}, dt.DataType) +} + +func TestConvertProtoStructToGoStruct(t *testing.T) { + protoType := &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "test", + DataType: &proto.DataType{Kind: &proto.DataType_Integer_{}}, + Nullable: true, + }, + }, + }, + }, + } + structType, err := types.ConvertProtoDataTypeToStructType(protoType) + assert.NoError(t, err) + assert.Equal(t, 1, len(structType.Fields)) + assert.Equal(t, "test", structType.Fields[0].Name) + assert.IsType(t, types.IntegerType{}, structType.Fields[0].DataType) + + // Check for input type that is not a struct type and it returns an error. 
+ protoType = &proto.DataType{Kind: &proto.DataType_Integer_{}} + structType, err = types.ConvertProtoDataTypeToStructType(protoType) + assert.Error(t, err) +} From dd011964e269f79b25a4b22edc5e7ef893b12251 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 9 Jul 2024 07:57:47 +0200 Subject: [PATCH 5/7] merging master --- .../main.go | 1 - spark/sql/dataframe.go | 6 ++-- spark/sql/executeplanclient.go | 36 ------------------- spark/sql/types/arrow.go | 16 +++++++++ spark/sql/types/arrow_test.go | 24 ++++++++----- spark/sql/types/conversion.go | 16 +++++++++ spark/sql/types/conversion_test.go | 16 +++++++++ 7 files changed, 66 insertions(+), 49 deletions(-) diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go index 1fd6965..5e3a9a6 100644 --- a/cmd/spark-connect-example-spark-session/main.go +++ b/cmd/spark-connect-example-spark-session/main.go @@ -22,7 +22,6 @@ import ( "log" "github.com/apache/spark-connect-go/v35/spark/sql" - "github.com/apache/spark-connect-go/v35/spark/sql/session" "github.com/apache/spark-connect-go/v35/spark/sql/utils" ) diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index 47ba276..273eaa6 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -119,8 +119,7 @@ func (df *dataFrameImpl) WriteResult(ctx context.Context, collector ResultCollec return err } - var rows []Row - rows = make([]Row, table.NumRows()) + rows := make([]Row, table.NumRows()) values, err := types.ReadArrowTable(table) if err != nil { @@ -164,8 +163,7 @@ func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) { return nil, err } - var rows []Row - rows = make([]Row, table.NumRows()) + rows := make([]Row, table.NumRows()) values, err := types.ReadArrowTable(table) if err != nil { diff --git a/spark/sql/executeplanclient.go b/spark/sql/executeplanclient.go index 81123b0..831ad94 100644 --- a/spark/sql/executeplanclient.go +++ b/spark/sql/executeplanclient.go @@ -15,39 +15,3 @@ // limitations under the License. package sql - -import ( - "errors" - "fmt" - "io" - - proto "github.com/apache/spark-connect-go/v35/internal/generated" - "github.com/apache/spark-connect-go/v35/spark/sparkerrors" -) - -type ExecutePlanClient struct { - proto.SparkConnectService_ExecutePlanClient -} - -func NewExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) *ExecutePlanClient { - return &ExecutePlanClient{ - responseClient, - } -} - -// consumeAll reads through the returned GRPC stream from Spark Connect Driver. It will -// discard the returned data if there is no error. This is necessary for handling GRPC response for -// saving data frame, since such consuming will trigger Spark Connect Driver really saving data frame. -// If we do not consume the returned GRPC stream, Spark Connect Driver will not really save data frame. -func (c *ExecutePlanClient) consumeAll() error { - for { - _, err := c.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - return nil - } else { - return sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError) - } - } - } -} diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go index 5d13b36..ef36a27 100644 --- a/spark/sql/types/arrow.go +++ b/spark/sql/types/arrow.go @@ -1,3 +1,19 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package types import ( diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go index 81604da..bd750cb 100644 --- a/spark/sql/types/arrow_test.go +++ b/spark/sql/types/arrow_test.go @@ -1,3 +1,19 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package types_test import ( @@ -302,11 +318,3 @@ func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) { } assert.Equal(t, "Unsupported", types.ConvertProtoDataTypeToDataType(unsupportedDataType).TypeName()) } - -type testCollector struct { - row []any -} - -func (t *testCollector) WriteRow(values []any) { - t.row = values -} diff --git a/spark/sql/types/conversion.go b/spark/sql/types/conversion.go index 35c372d..dcb6241 100644 --- a/spark/sql/types/conversion.go +++ b/spark/sql/types/conversion.go @@ -1,3 +1,19 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package types import ( diff --git a/spark/sql/types/conversion_test.go b/spark/sql/types/conversion_test.go index 70260b9..59946a3 100644 --- a/spark/sql/types/conversion_test.go +++ b/spark/sql/types/conversion_test.go @@ -1,3 +1,19 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package types_test import ( From 7c5fc0567af91c60fc9e7dbbf3323fb8e269cdf5 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 9 Jul 2024 08:04:35 +0200 Subject: [PATCH 6/7] merging master --- spark/sql/types/conversion_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/sql/types/conversion_test.go b/spark/sql/types/conversion_test.go index 59946a3..2fb0e36 100644 --- a/spark/sql/types/conversion_test.go +++ b/spark/sql/types/conversion_test.go @@ -72,6 +72,6 @@ func TestConvertProtoStructToGoStruct(t *testing.T) { // Check for input type that is not a struct type and it returns an error. protoType = &proto.DataType{Kind: &proto.DataType_Integer_{}} - structType, err = types.ConvertProtoDataTypeToStructType(protoType) + _, err = types.ConvertProtoDataTypeToStructType(protoType) assert.Error(t, err) } From 2b194a85458c0540389ef6839448e8eaa74ce14c Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 9 Jul 2024 10:41:28 +0200 Subject: [PATCH 7/7] adding more tests --- spark/client/client.go | 5 ++ spark/client/client_test.go | 104 +++++++++++++++++++++++++----------- spark/mocks/mocks.go | 34 ++++++++++++ 3 files changed, 112 insertions(+), 31 deletions(-) diff --git a/spark/client/client.go b/spark/client/client.go index decd8ba..de321ab 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -57,6 +57,11 @@ func (s *SparkExecutorImpl) ExecuteCommand(ctx context.Context, plan *proto.Plan }, } + // Check that the supplied plan is actually a command. + if plan.GetCommand() == nil { + return nil, nil, nil, sparkerrors.WithType(fmt.Errorf("the supplied plan does not contain a command"), sparkerrors.ExecutionError) + } + // Append the other items to the request. 
ctx = metadata.NewOutgoingContext(ctx, s.metadata) c, err := s.client.ExecutePlan(ctx, &request) diff --git a/spark/client/client_test.go b/spark/client/client_test.go index 060aa80..f024108 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -15,34 +15,76 @@ package client_test -//func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { -// ctx := context.Background() -// response := &proto.AnalyzePlanResponse{} -// c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, response, nil, nil), nil, "") -// resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) -// assert.NoError(t, err) -// assert.NotNil(t, resp) -//} -// -//func TestAnalyzePlanFailsIfClientFails(t *testing.T) { -// ctx := context.Background() -// c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, nil, assert.AnError, nil), nil, "") -// resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) -// assert.Nil(t, resp) -// assert.Error(t, err) -//} -// -//func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { -// ctx := context.Background() -// plan := &proto.Plan{} -// request := &proto.ExecutePlanRequest{ -// Plan: plan, -// UserContext: &proto.UserContext{ -// UserId: "na", -// }, -// } -// c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(request, &client.ExecutePlanClient{}, nil, nil, t), nil, "") -// resp, err := c.ExecutePlan(ctx, plan) -// assert.NoError(t, err) -// assert.NotNil(t, resp) -//} +import ( + "context" + "testing" + + proto "github.com/apache/spark-connect-go/v35/internal/generated" + "github.com/apache/spark-connect-go/v35/spark/client" + "github.com/apache/spark-connect-go/v35/spark/client/testutils" + "github.com/apache/spark-connect-go/v35/spark/mocks" + "github.com/apache/spark-connect-go/v35/spark/sparkerrors" + "github.com/stretchr/testify/assert" +) + +func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { + ctx := context.Background() + response := &proto.AnalyzePlanResponse{} + c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, response, nil, nil), nil, "") + resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestAnalyzePlanFailsIfClientFails(t *testing.T) { + ctx := context.Background() + c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(nil, nil, nil, assert.AnError, nil), nil, "") + resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) + assert.Nil(t, resp) + assert.Error(t, err) +} + +func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { + ctx := context.Background() + plan := &proto.Plan{} + request := &proto.ExecutePlanRequest{ + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + + // Generate a mock client + responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone) + + c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(request, responseStream, nil, nil, t), nil, "") + resp, err := c.ExecutePlan(ctx, plan) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestExecutePlanCallsExecuteCommandOnClient(t *testing.T) { + ctx := context.Background() + plan := &proto.Plan{} + request := &proto.ExecutePlanRequest{ + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + + // Generate a mock client + responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) + + // Check that the execution fails if no command is supplied. 
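+ // ExecuteCommand rejects plans without a command locally (via the GetCommand check)
+ // before any RPC is issued, so the mocked response stream is untouched by this call.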
+ c := client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(request, responseStream, nil, nil, t), nil, "")
+ _, _, _, err := c.ExecuteCommand(ctx, plan)
+ assert.ErrorIs(t, err, sparkerrors.ExecutionError)
+
+ // Generate a command and the execution should succeed.
+ sqlCommand := mocks.NewSqlCommand("select range(10)")
+ request.Plan = sqlCommand
+ c = client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(request, responseStream, nil, nil, t), nil, "")
+ _, _, _, err = c.ExecuteCommand(ctx, sqlCommand)
+ assert.NoError(t, err)
+}
diff --git a/spark/mocks/mocks.go b/spark/mocks/mocks.go
index 7946718..98272d6 100644
--- a/spark/mocks/mocks.go
+++ b/spark/mocks/mocks.go
@@ -18,6 +18,7 @@ package mocks
 
 import (
 "context"
+ "io"
 
 proto "github.com/apache/spark-connect-go/v35/internal/generated"
 "google.golang.org/grpc/metadata"
@@ -34,6 +35,25 @@ type ProtoClient struct {
 sent int
 }
 
+// ExecutePlanResponseDone is a response that indicates the plan execution is done.
+var ExecutePlanResponseDone = MockResponse{
+ Resp: &proto.ExecutePlanResponse{
+ ResponseType: &proto.ExecutePlanResponse_ResultComplete_{
+ ResultComplete: &proto.ExecutePlanResponse_ResultComplete{},
+ },
+ },
+ Err: nil,
+}
+
+var ExecutePlanResponseEOF = MockResponse{
+ Err: io.EOF,
+}
+
+// NewProtoClientMock creates a new mock client that returns the given responses.
+func NewProtoClientMock(responses ...*MockResponse) *ProtoClient {
+ return &ProtoClient{RecvResponse: responses}
+}
+
 func (p *ProtoClient) Recv() (*proto.ExecutePlanResponse, error) {
 val := p.RecvResponse[p.sent]
 p.sent += 1
@@ -63,3 +83,17 @@ func (p *ProtoClient) SendMsg(m interface{}) error {
 func (p *ProtoClient) RecvMsg(m interface{}) error {
 return p.RecvResponse[p.sent].Err
 }
+
+func NewSqlCommand(sql string) *proto.Plan {
+ return &proto.Plan{
+ OpType: &proto.Plan_Command{
+ Command: &proto.Command{
+ CommandType: &proto.Command_SqlCommand{
+ SqlCommand: &proto.SqlCommand{
+ Sql: sql,
+ },
+ },
+ },
+ },
+ }
+}