diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go
index 5f63bcc..71f6f07 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -39,29 +39,29 @@ func main() {
     }
     defer utils.WarnOnError(spark.Stop, func(err error) {})
 
-    //df, err := spark.Sql(ctx, "select * from range(100)")
-    //if err != nil {
-    //    log.Fatalf("Failed: %s", err)
-    //}
-    //
-    //df, _ = df.FilterByString("id < 10")
-    //err = df.Show(ctx, 100, false)
-    //if err != nil {
-    //    log.Fatalf("Failed: %s", err)
-    //}
-    //
-    //df, err = spark.Sql(ctx, "select * from range(100)")
-    //if err != nil {
-    //    log.Fatalf("Failed: %s", err)
-    //}
-    //
-    //df, _ = df.Filter(functions.Col("id").Lt(functions.Expr("10")))
-    //err = df.Show(ctx, 100, false)
-    //if err != nil {
-    //    log.Fatalf("Failed: %s", err)
-    //}
-
-    df, _ := spark.Sql(ctx, "select * from range(100)")
+    df, err := spark.Sql(ctx, "select id2 from range(100)")
+    if err != nil {
+        log.Fatalf("Failed: %s", err)
+    }
+
+    df, _ = df.FilterByString("id < 10")
+    err = df.Show(ctx, 100, false)
+    if err != nil {
+        log.Fatalf("Failed: %s", err)
+    }
+
+    df, err = spark.Sql(ctx, "select * from range(100)")
+    if err != nil {
+        log.Fatalf("Failed: %s", err)
+    }
+
+    df, _ = df.Filter(functions.Col("id").Lt(functions.Expr("10")))
+    err = df.Show(ctx, 100, false)
+    if err != nil {
+        log.Fatalf("Failed: %s", err)
+    }
+
+    df, _ = spark.Sql(ctx, "select * from range(100)")
     df, err = df.Filter(functions.Col("id").Lt(functions.Lit(20)))
     if err != nil {
         log.Fatalf("Failed: %s", err)
diff --git a/spark/client/client.go b/spark/client/client.go
index ed65f44..a9b163e 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -111,8 +111,8 @@ func (s *SparkExecutorImpl) AnalyzePlan(ctx context.Context, plan *proto.Plan) (
     ctx = metadata.NewOutgoingContext(ctx, s.metadata)
 
     response, err := s.client.AnalyzePlan(ctx, &request)
-    if err != nil {
-        return nil, sparkerrors.WithType(fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
+    if se := sparkerrors.FromRPCError(err); se != nil {
+        return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
     }
     return response, nil
 }
@@ -126,6 +126,8 @@ func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string) S
     }
 }
 
+// NewSparkExecutorFromClient creates a new SparkExecutor from an existing client and is mostly
+// used in testing.
 func NewSparkExecutorFromClient(client proto.SparkConnectServiceClient, md metadata.MD, sessionId string) SparkExecutor {
     return &SparkExecutorImpl{
         client: client,
@@ -156,11 +158,15 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
 
     for {
         resp, err := c.responseStream.Recv()
+        // EOF is received when the last message has been processed and the stream
+        // finished normally.
         if err == io.EOF {
            break
         }
-        if err != nil {
-            return nil, nil, sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError)
+
+        // If the error was not EOF, there might be another error.
+        if se := sparkerrors.FromRPCError(err); se != nil {
+            return nil, nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
         }
 
         // Process the message
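Editor's note on the guard pattern used above (not part of the patch): sparkerrors.FromRPCError, added in the next file, passes a nil error through as nil because status.Convert(nil) returns nil. The new `if se := sparkerrors.FromRPCError(err); se != nil` check is therefore equivalent to the old `if err != nil` check on the success path while enriching real failures. A minimal standalone sketch, assuming the module import path github.com/apache/spark-connect-go/v35:

package main

import (
    "fmt"

    "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

func main() {
    // A nil error never produces a SparkError, so the guard is skipped.
    if se := sparkerrors.FromRPCError(nil); se != nil {
        fmt.Println("unreachable")
    }

    // A real gRPC error is converted and keeps its code and message.
    err := status.Error(codes.Internal, "boom")
    if se := sparkerrors.FromRPCError(err); se != nil {
        fmt.Println(se.Code, se.Message) // Internal boom
    }
}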
diff --git a/spark/sparkerrors/errors.go b/spark/sparkerrors/errors.go
index 030db86..eecd34d 100644
--- a/spark/sparkerrors/errors.go
+++ b/spark/sparkerrors/errors.go
@@ -17,8 +17,13 @@ package sparkerrors
 
 import (
+    "encoding/json"
     "errors"
     "fmt"
+
+    "google.golang.org/genproto/googleapis/rpc/errdetails"
+    "google.golang.org/grpc/codes"
+    "google.golang.org/grpc/status"
 )
 
 type wrappedError struct {
@@ -65,3 +70,71 @@ type InvalidServerSideSessionError struct {
 func (e InvalidServerSideSessionError) Error() string {
     return fmt.Sprintf("Received invalid session id %s, expected %s", e.ReceivedSessionId, e.OwnSessionId)
 }
+
+// SparkError represents an error that is returned from Spark itself. It captures details of the
+// error that allow a better understanding of what went wrong and whether the error can be
+// retried or not.
+type SparkError struct {
+    // SqlState is the SQL state of the error.
+    SqlState string
+    // ErrorClass is the class of the error.
+    ErrorClass string
+    // Reason is typically the class name of the exception thrown on the Spark side, if set.
+    Reason string
+    // Message is the human-readable message of the error.
+    Message string
+    // Code is the gRPC status code of the error.
+    Code codes.Code
+    // ErrorId is the unique id of the error. It can be used to fetch more details about
+    // the error using an additional RPC from the server.
+    ErrorId string
+    // Parameters are the parameters that are used to format the error message.
+    Parameters map[string]string
+    status     *status.Status
+}
+
+func (e SparkError) Error() string {
+    if e.Code == codes.Internal && e.SqlState != "" {
+        return fmt.Sprintf("[%s] %s. SQLSTATE: %s", e.ErrorClass, e.Message, e.SqlState)
+    } else {
+        return fmt.Sprintf("[%s] %s", e.Code.String(), e.Message)
+    }
+}
+
+// FromRPCError converts a gRPC error to a SparkError. If the error is not a gRPC error, it is
+// converted into a plain "Unknown" gRPC status. If no error was observed, it returns nil.
+func FromRPCError(e error) *SparkError {
+    status := status.Convert(e)
+    // If there was no error, simply pass through.
+    if status == nil {
+        return nil
+    }
+    result := &SparkError{
+        Message: status.Message(),
+        Code:    status.Code(),
+        status:  status,
+    }
+
+    // Now, check if we can extract the error info from the details.
+    for _, d := range status.Details() {
+        switch info := d.(type) {
+        case *errdetails.ErrorInfo:
+            // Parse the parameters from the error details, but only parse them if
+            // they're present.
+            var params map[string]string
+            if v, ok := info.GetMetadata()["messageParameters"]; ok {
+                err := json.Unmarshal([]byte(v), &params)
+                if err == nil {
+                    // Only keep the parameters if they were properly formatted JSON;
+                    // if for some reason this is not the case, the parse error is ignored.
+                    result.Parameters = params
+                }
+            }
+            result.SqlState = info.GetMetadata()["sqlState"]
+            result.ErrorClass = info.GetMetadata()["errorClass"]
+            result.ErrorId = info.GetMetadata()["errorId"]
+            result.Reason = info.Reason
+        }
+    }
+    return result
+}
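Editor's sketch of the conversion above, mirroring the tests below: it builds a gRPC status carrying errdetails.ErrorInfo and converts it with the new helper. The module import path and the metadata values (error class, SQL state, message parameters) are illustrative assumptions, not output captured from a real server:

package main

import (
    "fmt"

    "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
    "google.golang.org/genproto/googleapis/rpc/errdetails"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

func main() {
    st := status.New(codes.Internal, "AnalysisException")
    // Attach the ErrorInfo details the way the tests construct them.
    st, _ = st.WithDetails(&errdetails.ErrorInfo{
        Reason: "AnalysisException",
        Domain: "spark.sql",
        Metadata: map[string]string{
            "sqlState":          "42703",
            "errorClass":        "UNRESOLVED_COLUMN.WITH_SUGGESTION",
            "errorId":           "some-error-id",
            "messageParameters": `{"objectName":"id2"}`,
        },
    })

    se := sparkerrors.FromRPCError(st.Err())
    fmt.Println(se.ErrorClass) // UNRESOLVED_COLUMN.WITH_SUGGESTION
    fmt.Println(se.SqlState)   // 42703
    fmt.Println(se.Parameters) // map[objectName:id2]
    fmt.Println(se.Error())    // [UNRESOLVED_COLUMN.WITH_SUGGESTION] AnalysisException. SQLSTATE: 42703
}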
diff --git a/spark/sparkerrors/errors_test.go b/spark/sparkerrors/errors_test.go
index d12a1fb..184ec97 100644
--- a/spark/sparkerrors/errors_test.go
+++ b/spark/sparkerrors/errors_test.go
@@ -18,6 +18,10 @@ package sparkerrors
 import (
     "testing"
 
+    "google.golang.org/genproto/googleapis/rpc/errdetails"
+    "google.golang.org/grpc/codes"
+    "google.golang.org/grpc/status"
+
     "github.com/stretchr/testify/assert"
 )
 
@@ -30,3 +34,148 @@ func TestErrorStringContainsErrorType(t *testing.T) {
     err := WithType(assert.AnError, ConnectionError)
     assert.Contains(t, err.Error(), ConnectionError.Error())
 }
+
+func TestGRPCErrorConversion(t *testing.T) {
+    err := status.Error(codes.Internal, "invalid argument")
+    se := FromRPCError(err)
+    assert.Equal(t, se.Code, codes.Internal)
+    assert.Equal(t, se.Message, "invalid argument")
+}
+
+func TestNonGRPCErrorsAreConvertedAsWell(t *testing.T) {
+    err := assert.AnError
+    se := FromRPCError(err)
+    assert.Equal(t, se.Code, codes.Unknown)
+    assert.Equal(t, se.Message, assert.AnError.Error())
+}
+
+func TestErrorDetailsExtractionFromGRPCStatus(t *testing.T) {
+    status := status.New(codes.Internal, "AnalysisException")
+    status, _ = status.WithDetails(&errdetails.ErrorInfo{
+        Reason:   "AnalysisException",
+        Domain:   "spark.sql",
+        Metadata: map[string]string{},
+    })
+
+    err := status.Err()
+    se := FromRPCError(err)
+    assert.Equal(t, codes.Internal, se.Code)
+    assert.Equal(t, "AnalysisException", se.Message)
+    assert.Equal(t, "AnalysisException", se.Reason)
+}
+
+func TestErrorDetailsWithSqlStateAndClass(t *testing.T) {
+    status := status.New(codes.Internal, "AnalysisException")
+    status, _ = status.WithDetails(&errdetails.ErrorInfo{
+        Reason: "AnalysisException",
+        Domain: "spark.sql",
+        Metadata: map[string]string{
+            "sqlState":          "42000",
+            "errorClass":        "ERROR_CLASS",
+            "errorId":           "errorId",
+            "messageParameters": "",
+        },
+    })
+
+    err := status.Err()
+    se := FromRPCError(err)
+    assert.Equal(t, codes.Internal, se.Code)
+    assert.Equal(t, "AnalysisException", se.Message)
+    assert.Equal(t, "AnalysisException", se.Reason)
+    assert.Equal(t, "42000", se.SqlState)
+    assert.Equal(t, "ERROR_CLASS", se.ErrorClass)
+    assert.Equal(t, "errorId", se.ErrorId)
+}
+
+func TestErrorDetailsWithMessageParameterParsing(t *testing.T) {
+    type param struct {
+        TestName string
+        Input    string
+        Expected map[string]string
+    }
+
+    params := []param{
+        {"empty input", "", nil},
+        {"parse error", "{", nil},
+        {"empty object", "{}", map[string]string{}},
+        {"valid input", "{\"key\":\"value\"}", map[string]string{"key": "value"}},
+    }
+
+    for _, p := range params {
+        t.Run(p.TestName, func(t *testing.T) {
+            status := status.New(codes.Internal, "AnalysisException")
+            status, _ = status.WithDetails(&errdetails.ErrorInfo{
+                Reason: "AnalysisException",
+                Domain: "spark.sql",
+                Metadata: map[string]string{
+                    "sqlState":          "42000",
+                    "errorClass":        "ERROR_CLASS",
+                    "errorId":           "errorId",
+                    "messageParameters": p.Input,
+                },
+            })
+
+            err := status.Err()
+            se := FromRPCError(err)
+            assert.Equal(t, codes.Internal, se.Code)
+            assert.Equal(t, "AnalysisException", se.Message)
+            assert.Equal(t, "AnalysisException", se.Reason)
+            assert.Equal(t, "42000", se.SqlState)
+            assert.Equal(t, "ERROR_CLASS", se.ErrorClass)
+            assert.Equal(t, "errorId", se.ErrorId)
+            assert.Equal(t, p.Expected, se.Parameters)
+        })
+    }
+}
+
+func TestSparkError_Error(t *testing.T) {
+    type fields struct {
+        SqlState   string
+        ErrorClass string
+        Reason     string
+        Message    string
+        Code       codes.Code
+        ErrorId    string
+        Parameters map[string]string
+        status     *status.Status
+    }
+    tests := []struct {
+        name   string
+        fields fields
+        want   string
+    }{
+        {
+            "UNKNOWN",
+            fields{
+                Code:    codes.Unknown,
+                Message: "Unknown error",
+            },
+            "[Unknown] Unknown error",
+        },
+        {
+            "Analysis Exception",
+            fields{
+                SqlState:   "42703",
+                ErrorClass: "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+                Message:    "A column, variable, or function parameter with name `id2` cannot be resolved. Did you mean one of the following? [`id`]",
+                Code:       codes.Internal,
+            },
+            "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with name `id2` cannot be resolved. Did you mean one of the following? [`id`]. SQLSTATE: 42703",
+        },
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            e := SparkError{
+                SqlState:   tt.fields.SqlState,
+                ErrorClass: tt.fields.ErrorClass,
+                Reason:     tt.fields.Reason,
+                Message:    tt.fields.Message,
+                Code:       tt.fields.Code,
+                ErrorId:    tt.fields.ErrorId,
+                Parameters: tt.fields.Parameters,
+                status:     tt.fields.status,
+            }
+            assert.Equalf(t, tt.want, e.Error(), "Error()")
+        })
+    }
+}
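Editor's note: the SparkError doc comment mentions using the captured details to decide whether an error can be retried, but neither this change nor the package ships such a helper. The sketch below is a hypothetical illustration of how a caller could classify errors by their gRPC code; the function name, the chosen set of retryable codes, and the module path are assumptions made for the example:

package retryexample

import (
    "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
    "google.golang.org/grpc/codes"
)

// isRetryable is a hypothetical helper, not part of this change. Transient
// transport-level codes are treated as retryable; everything else, including
// Internal errors that carry a SQL state (e.g. an unresolved column), is not.
func isRetryable(se *sparkerrors.SparkError) bool {
    if se == nil {
        return false
    }
    switch se.Code {
    case codes.Unavailable, codes.DeadlineExceeded, codes.ResourceExhausted:
        return true
    default:
        return false
    }
}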