[SPARK-48982] Properly extract Spark Errors from GRPC Request
### What changes were proposed in this pull request?
Properly extract the Spark exception information from the gRPC status and the associated error info details. The errors are then mapped to a `SparkError` type that allows accessing the metadata.
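For illustration, a minimal sketch (not part of this diff) of how a caller might inspect the extracted metadata; the helper name and import path are assumptions:

```go
package main

import (
	"fmt"

	"github.com/apache/spark-connect-go/v35/spark/sparkerrors" // import path is an assumption
)

// describe prints the Spark error metadata extracted from a gRPC error, if any.
func describe(err error) {
	if se := sparkerrors.FromRPCError(err); se != nil {
		fmt.Printf("[%s] SQLSTATE=%s: %s\n", se.ErrorClass, se.SqlState, se.Message)
	}
}
```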

### Why are the changes needed?
Compatibility: Spark Connect attaches structured error details to the gRPC status, and the client should surface them instead of an opaque error.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added new unit tests.

Closes #36 from grundprinzip/SPARK-48982.

Authored-by: Martin Grund <martin.grund@databricks.com>
Signed-off-by: Martin Grund <martin.grund@databricks.com>
grundprinzip committed Jul 24, 2024
1 parent 53797bd commit 957a4b3
Showing 4 changed files with 255 additions and 27 deletions.
46 changes: 23 additions & 23 deletions cmd/spark-connect-example-spark-session/main.go
@@ -39,29 +39,29 @@ func main() {
}
defer utils.WarnOnError(spark.Stop, func(err error) {})

//df, err := spark.Sql(ctx, "select * from range(100)")
//if err != nil {
// log.Fatalf("Failed: %s", err)
//}
//
//df, _ = df.FilterByString("id < 10")
//err = df.Show(ctx, 100, false)
//if err != nil {
// log.Fatalf("Failed: %s", err)
//}
//
//df, err = spark.Sql(ctx, "select * from range(100)")
//if err != nil {
// log.Fatalf("Failed: %s", err)
//}
//
//df, _ = df.Filter(functions.Col("id").Lt(functions.Expr("10")))
//err = df.Show(ctx, 100, false)
//if err != nil {
// log.Fatalf("Failed: %s", err)
//}

df, _ := spark.Sql(ctx, "select * from range(100)")
df, err := spark.Sql(ctx, "select id2 from range(100)")
if err != nil {
log.Fatalf("Failed: %s", err)
}

df, _ = df.FilterByString("id < 10")
err = df.Show(ctx, 100, false)
if err != nil {
log.Fatalf("Failed: %s", err)
}

df, err = spark.Sql(ctx, "select * from range(100)")
if err != nil {
log.Fatalf("Failed: %s", err)
}

df, _ = df.Filter(functions.Col("id").Lt(functions.Expr("10")))
err = df.Show(ctx, 100, false)
if err != nil {
log.Fatalf("Failed: %s", err)
}

df, _ = spark.Sql(ctx, "select * from range(100)")
df, err = df.Filter(functions.Col("id").Lt(functions.Lit(20)))
if err != nil {
log.Fatalf("Failed: %s", err)
14 changes: 10 additions & 4 deletions spark/client/client.go
@@ -111,8 +111,8 @@ func (s *SparkExecutorImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (
ctx = metadata.NewOutgoingContext(ctx, s.metadata)

response, err := s.client.AnalyzePlan(ctx, &request)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}
return response, nil
}
@@ -126,6 +126,8 @@ func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string) S
}
}

// NewSparkExecutorFromClient creates a new SparkExecutor from an existing client and is mostly
// used in testing.
func NewSparkExecutorFromClient(client proto.SparkConnectServiceClient, md metadata.MD, sessionId string) SparkExecutor {
return &SparkExecutorImpl{
client: client,
@@ -156,11 +158,15 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {

for {
resp, err := c.responseStream.Recv()
// EOF is received when the last message has been processed and the stream
// finished normally.
if err == io.EOF {
break
}
if err != nil {
return nil, nil, sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError)

// If the error was not EOF, check whether the RPC failed with a Spark error.
if se := sparkerrors.FromRPCError(err); se != nil {
return nil, nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
}

// Process the message
73 changes: 73 additions & 0 deletions spark/sparkerrors/errors.go
@@ -17,8 +17,13 @@
package sparkerrors

import (
"encoding/json"
"errors"
"fmt"

"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type wrappedError struct {
@@ -65,3 +70,71 @@ type InvalidServerSideSessionError struct {
func (e InvalidServerSideSessionError) Error() string {
return fmt.Sprintf("Received invalid session id %s, expected %s", e.ReceivedSessionId, e.OwnSessionId)
}

// SparkError represents an error that is returned from Spark itself. It captures details of
// the error that allow a better understanding of what went wrong, for example, whether the
// error can be retried.
type SparkError struct {
// SqlState is the SQL state of the error.
SqlState string
// ErrorClass is the class of the error.
ErrorClass string
// Reason, if set, is typically the name of the exception class thrown on the Spark side.
Reason string
// Message is the human-readable message of the error.
Message string
// Code is the gRPC status code of the error.
Code codes.Code
// ErrorId is the unique id of the error. It can be used to fetch more details about
// the error using an additional RPC from the server.
ErrorId string
// Parameters are the parameters that are used to format the error message.
Parameters map[string]string
status *status.Status
}

func (e SparkError) Error() string {
if e.Code == codes.Internal && e.SqlState != "" {
return fmt.Sprintf("[%s] %s. SQLSTATE: %s", e.ErrorClass, e.Message, e.SqlState)
} else {
return fmt.Sprintf("[%s] %s", e.Code.String(), e.Message)
}
}

// FromRPCError converts a gRPC error to a SparkError. If the error is not a gRPC status
// error, it is converted into a plain status with code "Unknown". If no error was observed,
// it returns nil.
func FromRPCError(e error) *SparkError {
status := status.Convert(e)
// If there was no error, simply pass through.
if status == nil {
return nil
}
result := &SparkError{
Message: status.Message(),
Code: status.Code(),
status: status,
}

// Now, let's check if we can extract the error info from the details.
for _, d := range status.Details() {
switch info := d.(type) {
case *errdetails.ErrorInfo:
// Parse the parameters from the error details, but only parse them if
// they're present.
var params map[string]string
if v, ok := info.GetMetadata()["messageParameters"]; ok {
err := json.Unmarshal([]byte(v), &params)
if err == nil {
// The message parameters are expected to be properly formatted JSON; if for
// some reason they are not, the parse error is ignored.
result.Parameters = params
}
}
result.SqlState = info.GetMetadata()["sqlState"]
result.ErrorClass = info.GetMetadata()["errorClass"]
result.ErrorId = info.GetMetadata()["errorId"]
result.Reason = info.Reason
}
}
return result
}
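For a compact illustration of the conversion, here is a sketch that mirrors the unit tests below; the reason and metadata values are made up:

```go
// Build a status the way a Spark Connect server might report an analysis
// error, then convert it. All concrete values here are illustrative.
st := status.New(codes.Internal, "AnalysisException")
st, _ = st.WithDetails(&errdetails.ErrorInfo{
	Reason: "org.apache.spark.sql.AnalysisException",
	Domain: "org.apache.spark",
	Metadata: map[string]string{
		"sqlState":   "42703",
		"errorClass": "UNRESOLVED_COLUMN.WITH_SUGGESTION",
	},
})
se := FromRPCError(st.Err())
// se.Code == codes.Internal, se.SqlState == "42703",
// se.ErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION"
```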
149 changes: 149 additions & 0 deletions spark/sparkerrors/errors_test.go
@@ -18,6 +18,10 @@ package sparkerrors
import (
"testing"

"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/stretchr/testify/assert"
)

@@ -30,3 +34,148 @@ func TestErrorStringContainsErrorType(t *testing.T) {
err := WithType(assert.AnError, ConnectionError)
assert.Contains(t, err.Error(), ConnectionError.Error())
}

func TestGRPCErrorConversion(t *testing.T) {
err := status.Error(codes.Internal, "invalid argument")
se := FromRPCError(err)
assert.Equal(t, codes.Internal, se.Code)
assert.Equal(t, "invalid argument", se.Message)
}

func TestNonGRPCErrorsAreConvertedAsWell(t *testing.T) {
err := assert.AnError
se := FromRPCError(err)
assert.Equal(t, codes.Unknown, se.Code)
assert.Equal(t, assert.AnError.Error(), se.Message)
}

func TestErrorDetailsExtractionFromGRPCStatus(t *testing.T) {
status := status.New(codes.Internal, "AnalysisException")
status, _ = status.WithDetails(&errdetails.ErrorInfo{
Reason: "AnalysisException",
Domain: "spark.sql",
Metadata: map[string]string{},
})

err := status.Err()
se := FromRPCError(err)
assert.Equal(t, codes.Internal, se.Code)
assert.Equal(t, "AnalysisException", se.Message)
assert.Equal(t, "AnalysisException", se.Reason)
}

func TestErrorDetailsWithSqlStateAndClass(t *testing.T) {
status := status.New(codes.Internal, "AnalysisException")
status, _ = status.WithDetails(&errdetails.ErrorInfo{
Reason: "AnalysisException",
Domain: "spark.sql",
Metadata: map[string]string{
"sqlState": "42000",
"errorClass": "ERROR_CLASS",
"errorId": "errorId",
"messageParameters": "",
},
})

err := status.Err()
se := FromRPCError(err)
assert.Equal(t, codes.Internal, se.Code)
assert.Equal(t, "AnalysisException", se.Message)
assert.Equal(t, "AnalysisException", se.Reason)
assert.Equal(t, "42000", se.SqlState)
assert.Equal(t, "ERROR_CLASS", se.ErrorClass)
assert.Equal(t, "errorId", se.ErrorId)
}

func TestErrorDetailsWithMessageParameterParsing(t *testing.T) {
type param struct {
TestName string
Input string
Expected map[string]string
}

params := []param{
{"empty input", "", nil},
{"empty input", "{", nil},
{"parse error", "{}", map[string]string{}},
{"valid input", "{\"key\":\"value\"}", map[string]string{"key": "value"}},
}

for _, p := range params {
t.Run(p.TestName, func(t *testing.T) {
status := status.New(codes.Internal, "AnalysisException")
status, _ = status.WithDetails(&errdetails.ErrorInfo{
Reason: "AnalysisException",
Domain: "spark.sql",
Metadata: map[string]string{
"sqlState": "42000",
"errorClass": "ERROR_CLASS",
"errorId": "errorId",
"messageParameters": p.Input,
},
})

err := status.Err()
se := FromRPCError(err)
assert.Equal(t, codes.Internal, se.Code)
assert.Equal(t, "AnalysisException", se.Message)
assert.Equal(t, "AnalysisException", se.Reason)
assert.Equal(t, "42000", se.SqlState)
assert.Equal(t, "ERROR_CLASS", se.ErrorClass)
assert.Equal(t, "errorId", se.ErrorId)
assert.Equal(t, p.Expected, se.Parameters)
})
}
}

func TestSparkError_Error(t *testing.T) {
type fields struct {
SqlState string
ErrorClass string
Reason string
Message string
Code codes.Code
ErrorId string
Parameters map[string]string
status *status.Status
}
tests := []struct {
name string
fields fields
want string
}{
{
"UNKNOWN",
fields{
Code: codes.Unknown,
Message: "Unknown error",
},
"[Unknown] Unknown error",
},
{
"Analysis Exception",
fields{
SqlState: "42703",
ErrorClass: "UNRESOLVED_COLUMN.WITH_SUGGESTION",
Message: "A column, variable, or function parameter with name `id2` cannot be resolved. Did you mean one of the following? [`id`]",
Code: codes.Internal,
},
"[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `id2` cannot be resolved. Did you mean one of the following? [`id`]. SQLSTATE: 42703",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := SparkError{
SqlState: tt.fields.SqlState,
ErrorClass: tt.fields.ErrorClass,
Reason: tt.fields.Reason,
Message: tt.fields.Message,
Code: tt.fields.Code,
ErrorId: tt.fields.ErrorId,
Parameters: tt.fields.Parameters,
status: tt.fields.status,
}
assert.Equalf(t, tt.want, e.Error(), "Error()")
})
}
}
