[SPARK-48754] Further refactoring to make the code more stable and accessible #33

Closed · wants to merge 8 commits
2 changes: 1 addition & 1 deletion Makefile
@@ -16,7 +16,7 @@
#

FIRST_GOPATH := $(firstword $(subst :, ,$(GOPATH)))
PKGS := $(shell go list ./... | grep -v /tests | grep -v /xcpb | grep -v /gpb)
PKGS := $(shell go list ./... | grep -v /tests | grep -v /xcpb | grep -v /gpb | grep -v /generated)
GOFILES_NOVENDOR := $(shell find . -name vendor -prune -o -type f -name '*.go' -not -name '*.pb.go' -print)
GOFILES_BUILD := $(shell find . -type f -name '*.go' -not -name '*_test.go')
PROTOFILES := $(shell find . -name vendor -prune -o -type f -name '*.proto' -print)
34 changes: 3 additions & 31 deletions README.md
@@ -50,37 +50,9 @@ See [Quick Start Guide](quick-start.md)

## High Level Design

The following [diagram](https://textik.com/#ac299c8f32c4c342) shows the main code in the current prototype:

```
+-------------------+
| |
| dataFrameImpl |
| |
+-------------------+
|
|
+
+-------------------+
| |
| sparkSessionImpl |
| |
+-------------------+
|
|
+
+---------------------------+ +----------------+
| | | |
| SparkConnectServiceClient |--------------+| Spark Driver |
| | | |
+---------------------------+ +----------------+
```

`SparkConnectServiceClient` is a gRPC client that talks to the Spark Driver. `sparkSessionImpl` creates `dataFrameImpl`
instances. `dataFrameImpl` uses the gRPC client in `sparkSessionImpl` to communicate with the Spark Driver.

We mimic the logic of the Spark Connect Scala implementation and adopt common Go practices, e.g. returning an `error` value for
error handling.
The overall goal of the design is to strike a good balance between the principle of least surprise for
developers who are familiar with the APIs of Apache Spark and idiomatic Go usage. The high-level
structure of the packages roughly follows the PySpark guidance, but with Go idioms.
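
For illustration, here is a minimal sketch of this convention, mirroring the example application in this repository (the remote address is a placeholder):

```go
package main

import (
	"context"
	"log"

	"github.com/apache/spark-connect-go/v35/spark/sql"
)

func main() {
	ctx := context.Background()
	// Errors are returned as values rather than raised, following Go practice.
	spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(ctx)
	if err != nil {
		log.Fatalf("Failed to create Spark session: %s", err)
	}
	_ = spark // From here, the session is used to create and transform DataFrames.
}
```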

## Contributing

3 changes: 1 addition & 2 deletions cmd/spark-connect-example-spark-session/main.go
@@ -22,7 +22,6 @@ import (
"log"

"github.com/apache/spark-connect-go/v35/spark/sql"
"github.com/apache/spark-connect-go/v35/spark/sql/session"
"github.com/apache/spark-connect-go/v35/spark/sql/utils"
)

@@ -32,7 +32,7 @@ var remote = flag.String("remote", "sc://localhost:15002",
func main() {
flag.Parse()
ctx := context.Background()
spark, err := session.NewSessionBuilder().Remote(*remote).Build(ctx)
spark, err := sql.NewSessionBuilder().Remote(*remote).Build(ctx)
if err != nil {
log.Fatalf("Failed: %s", err)
}
251 changes: 251 additions & 0 deletions spark/client/client.go
@@ -0,0 +1,251 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
"context"
"errors"
"fmt"
"io"

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/spark-connect-go/v35/spark/sql/types"

"github.com/apache/spark-connect-go/v35/internal/generated"
proto "github.com/apache/spark-connect-go/v35/internal/generated"
"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

// SparkExecutor is the interface for executing a plan in Spark.
//
// This interface does not deal with the public Spark API abstractions; it operates roughly at the
// RPC API level and handles the necessary translation of Arrow to Row objects.
type SparkExecutor interface {
ExecutePlan(ctx context.Context, plan *generated.Plan) (*ExecutePlanClient, error)
ExecuteCommand(ctx context.Context, plan *generated.Plan) (arrow.Table, *types.StructType, map[string]any, error)
AnalyzePlan(ctx context.Context, plan *generated.Plan) (*generated.AnalyzePlanResponse, error)
}

type SparkExecutorImpl struct {
client proto.SparkConnectServiceClient
metadata metadata.MD
sessionId string
}

func (s *SparkExecutorImpl) ExecuteCommand(ctx context.Context, plan *proto.Plan) (arrow.Table, *types.StructType, map[string]any, error) {
request := proto.ExecutePlanRequest{
SessionId: s.sessionId,
Plan: plan,
UserContext: &proto.UserContext{
UserId: "na",
},
}

// Check that the supplied plan is actually a command.
if plan.GetCommand() == nil {
return nil, nil, nil, sparkerrors.WithType(fmt.Errorf("the supplied plan does not contain a command"), sparkerrors.ExecutionError)
}

// Attach the request metadata to the outgoing context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(ctx, &request)
if err != nil {
return nil, nil, nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
}

respHandler := NewExecutePlanClient(c, s.sessionId)
schema, table, err := respHandler.ToTable()
if err != nil {
return nil, nil, nil, err
}
return table, schema, respHandler.properties, nil
}

func (s *SparkExecutorImpl) ExecutePlan(ctx context.Context, plan *proto.Plan) (*ExecutePlanClient, error) {
request := proto.ExecutePlanRequest{
SessionId: s.sessionId,
Plan: plan,
UserContext: &proto.UserContext{
UserId: "na",
},
}

// Attach the request metadata to the outgoing context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)
c, err := s.client.ExecutePlan(ctx, &request)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
}
return NewExecutePlanClient(c, s.sessionId), nil
}

func (s *SparkExecutorImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
request := proto.AnalyzePlanRequest{
SessionId: s.sessionId,
Analyze: &proto.AnalyzePlanRequest_Schema_{
Schema: &proto.AnalyzePlanRequest_Schema{
Plan: plan,
},
},
UserContext: &proto.UserContext{
UserId: "na",
},
}
// Attach the request metadata to the outgoing context.
ctx = metadata.NewOutgoingContext(ctx, s.metadata)

response, err := s.client.AnalyzePlan(ctx, &request)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
}
return response, nil
}

func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string) SparkExecutor {
client := proto.NewSparkConnectServiceClient(conn)
return &SparkExecutorImpl{
client: client,
metadata: md,
sessionId: sessionId,
}
}

func NewSparkExecutorFromClient(client proto.SparkConnectServiceClient, md metadata.MD, sessionId string) SparkExecutor {
return &SparkExecutorImpl{
client: client,
metadata: md,
sessionId: sessionId,
}
}

// ExecutePlanClient is the wrapper around the result of the execution of a query plan using
// Spark Connect.
type ExecutePlanClient struct {
// The GRPC stream to read the response messages.
responseStream generated.SparkConnectService_ExecutePlanClient
// The schema of the result of the operation.
schema *types.StructType
// The sessionId is used to verify the server-side session.
sessionId string
done bool
properties map[string]any
}

// ToTable fetches all of the data from the response stream and converts it to the
// desired format, mirroring the generic toTable method in PySpark.
func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
var recordBatches []arrow.Record
var arrowSchema *arrow.Schema
recordBatches = make([]arrow.Record, 0)

for {
resp, err := c.responseStream.Recv()
if err == io.EOF {
break
}
if err != nil {
return nil, nil, sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError)
}

// Process the message

// Check that the server returned the session ID that we were expecting
// and that it has not changed.
if resp.GetSessionId() != c.sessionId {
return c.schema, nil, sparkerrors.InvalidServerSideSessionError{
OwnSessionId: c.sessionId,
ReceivedSessionId: resp.GetSessionId(),
}
}

// If the response already has the schema set, convert the proto DataType
// to a StructType.
if resp.Schema != nil {
c.schema, err = types.ConvertProtoDataTypeToStructType(resp.Schema)
if err != nil {
return nil, nil, err
}
}

switch x := resp.ResponseType.(type) {
case *proto.ExecutePlanResponse_SqlCommandResult_:
if val := x.SqlCommandResult.GetRelation(); val != nil {
c.properties["sql_command_result"] = val
}
case *proto.ExecutePlanResponse_ArrowBatch_:
// Convert the Arrow batch data into a record and collect it for the result table.
record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema)
if err != nil {
return nil, nil, err
}
arrowSchema = record.Schema()
record.Retain()
recordBatches = append(recordBatches, record)
case *proto.ExecutePlanResponse_ResultComplete_:
c.done = true
default:
fmt.Printf("Received unsupported response type: %T\n", x)
//return nil, nil, &sparkerrors.UnsupportedResponseTypeError{
// ResponseType: x,
//}
}
}

// Check that the result is logically complete. The result might not be complete
// because after 2 minutes the server will interrupt the connection and we have to
// send a ReAttach execute request.
//if !c.done {
// return nil, nil, sparkerrors.WithType(fmt.Errorf("the result is not complete"), sparkerrors.ExecutionError)
//}
// Return the schema and table.
if arrowSchema == nil {
return c.schema, nil, nil
} else {
return c.schema, array.NewTableFromRecords(arrowSchema, recordBatches), nil
}
}

func NewExecutePlanClient(
responseClient proto.SparkConnectService_ExecutePlanClient,
sessionId string,
) *ExecutePlanClient {
return &ExecutePlanClient{
responseStream: responseClient,
sessionId: sessionId,
done: false,
properties: make(map[string]any),
}
}

// ConsumeAll reads through the returned gRPC stream from the Spark Connect driver and
// discards the returned data, returning any error encountered. Consuming the stream is
// necessary when saving a data frame: the driver only actually performs the save once
// the response stream has been read, so an unconsumed stream leaves the data frame unsaved.
func (c *ExecutePlanClient) ConsumeAll() error {
for {
_, err := c.responseStream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
} else {
return sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError)
}
}
}
}
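
To make the pieces above concrete, here is a hedged sketch of how a caller might drive the new executor. The `conn`, `sessionId`, and `plan` values are assumed to exist (connection setup and plan construction are not shown), and `generated` is an internal package, so this is illustrative only, not part of the PR:

```go
package example

import (
	"context"

	"github.com/apache/spark-connect-go/v35/internal/generated"
	"github.com/apache/spark-connect-go/v35/spark/client"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// runPlan sketches the two execution paths offered by SparkExecutor.
func runPlan(ctx context.Context, conn *grpc.ClientConn, sessionId string, plan *generated.Plan) error {
	exec := client.NewSparkExecutor(conn, metadata.MD{}, sessionId)

	// Commands are fully materialized: ExecuteCommand drains the stream via
	// ToTable and returns the Arrow table, the schema, and any response
	// properties (e.g. "sql_command_result").
	if plan.GetCommand() != nil {
		table, schema, props, err := exec.ExecuteCommand(ctx, plan)
		if err != nil {
			return err
		}
		_, _, _ = table, schema, props
		return nil
	}

	// Plain plans return a streaming handler. ConsumeAll drains the stream,
	// which the driver requires before it performs side effects such as
	// actually saving a DataFrame.
	respHandler, err := exec.ExecutePlan(ctx, plan)
	if err != nil {
		return err
	}
	return respHandler.ConsumeAll()
}
```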