From a3678fef7ea0f8eae03dc1e309cb8125180ad29d Mon Sep 17 00:00:00 2001 From: Harsh Date: Sat, 20 Dec 2025 20:20:05 +0530 Subject: [PATCH] rpcperms: enable rich gRPC status errors in middleware API --- docs/release-notes/release-notes-0.21.0.md | 5 + rpcperms/middleware_handler.go | 75 +++++++++++- rpcperms/middleware_handler_test.go | 135 +++++++++++++++++++++ 3 files changed, 211 insertions(+), 4 deletions(-) diff --git a/docs/release-notes/release-notes-0.21.0.md b/docs/release-notes/release-notes-0.21.0.md index 30b7263de56..ee2cedde27e 100644 --- a/docs/release-notes/release-notes-0.21.0.md +++ b/docs/release-notes/release-notes-0.21.0.md @@ -71,6 +71,11 @@ ## RPC Updates +* [Enabled rich gRPC status error support in the middleware + API](https://github.com/lightningnetwork/lnd/pull/10458), allowing middleware + to inspect and modify full gRPC error details including error codes, not just + plain error strings. + ## lncli Updates ## Breaking Changes diff --git a/rpcperms/middleware_handler.go b/rpcperms/middleware_handler.go index fa1e8510f5e..2de36de6424 100644 --- a/rpcperms/middleware_handler.go +++ b/rpcperms/middleware_handler.go @@ -12,13 +12,26 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/macaroons" + spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "gopkg.in/macaroon.v2" ) +const ( + // StatusTypeNameError is the type name used for plain error strings + // that are not gRPC status errors. + StatusTypeNameError = "error" + + // StatusTypeNameStatus is the fully qualified name of the + // google.rpc.Status proto message that is used to represent rich gRPC + // errors with proper error codes and details. + StatusTypeNameStatus = "google.rpc.Status" +) + var ( // ErrShuttingDown is the error that's returned when the server is // shutting down and a request cannot be served anymore. @@ -276,9 +289,20 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error, // proto message? response.replace = true if requestInfo.request.IsError { - response.replacement = errors.New( - string(t.ReplacementSerialized), + // Check if the original error was a + // rich gRPC status error. If so, we + // need to parse the replacement as a + // Status proto and reconstruct the + // error with the proper gRPC code. + replacement, err := parseErrorReplacement( + requestInfo.request.ProtoTypeName, + t.ReplacementSerialized, ) + if err != nil { + response.err = err + break + } + response.replacement = replacement break } @@ -432,10 +456,27 @@ func NewMessageInterceptionRequest(ctx context.Context, req.ProtoTypeName = string(proto.MessageName(t)) case error: - req.ProtoSerialized = []byte(t.Error()) - req.ProtoTypeName = "error" req.IsError = true + // Check if the error is a gRPC status error. If so, we + // serialize the underlying Status proto to allow middleware to + // inspect and modify the full error details including the gRPC + // error code. + st, ok := status.FromError(t) + if ok { + req.ProtoSerialized, err = proto.Marshal(st.Proto()) + if err != nil { + return nil, fmt.Errorf("cannot marshal "+ + "status proto: %w", err) + } + req.ProtoTypeName = StatusTypeNameStatus + } else { + // Not a gRPC status error, fall back to plain error + // string serialization. + req.ProtoSerialized = []byte(t.Error()) + req.ProtoTypeName = StatusTypeNameError + } + default: return nil, fmt.Errorf("unsupported type for interception "+ "request: %v", m) @@ -582,6 +623,32 @@ func parseProto(typeName string, serialized []byte) (proto.Message, error) { return msg.Interface(), nil } +// parseErrorReplacement parses a replacement error from its serialized form. +// If the original error was a rich gRPC status error (indicated by the +// StatusTypeNameStatus type name), it will parse the replacement as a +// google.rpc.Status proto and reconstruct a proper gRPC status error. +// Otherwise, it treats the replacement as a plain error string. +func parseErrorReplacement(typeName string, serialized []byte) (error, error) { + // If the original error was a rich gRPC status, parse the replacement + // as a Status proto and reconstruct the error with the proper gRPC + // code and details. + if typeName == StatusTypeNameStatus { + // Unmarshal directly into the google.rpc.Status proto type. + statusProto := &spb.Status{} + if err := proto.Unmarshal(serialized, statusProto); err != nil { + return nil, fmt.Errorf("cannot parse status proto: %w", + err) + } + + // Convert the proto back to a gRPC status error. + st := status.FromProto(statusProto) + return st.Err(), nil + } + + // For plain error strings, just create a new error from the bytes. + return errors.New(string(serialized)), nil +} + // replaceProtoMsg replaces the given target message with the content of the // replacement message. func replaceProtoMsg(target interface{}, replacement interface{}) error { diff --git a/rpcperms/middleware_handler_test.go b/rpcperms/middleware_handler_test.go index aa8715bf053..6fb6cc3607f 100644 --- a/rpcperms/middleware_handler_test.go +++ b/rpcperms/middleware_handler_test.go @@ -1,11 +1,16 @@ package rpcperms import ( + "context" "encoding/json" + "errors" "testing" "github.com/lightningnetwork/lnd/lnrpc" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" ) // TestReplaceProtoMsg makes sure the proto message replacement works as @@ -88,3 +93,133 @@ func jsonEqual(t *testing.T, expected, actual interface{}) { require.JSONEq(t, string(expectedJSON), string(actualJSON)) } + +// TestParseErrorReplacement tests that parseErrorReplacement correctly parses +// both plain error strings and rich gRPC status errors. +func TestParseErrorReplacement(t *testing.T) { + testCases := []struct { + name string + typeName string + serialized []byte + expectedErrMsg string + expectParseErr bool + }{{ + name: "plain error string", + typeName: StatusTypeNameError, + serialized: []byte("this is a plain error"), + expectedErrMsg: "this is a plain error", + }, { + name: "empty error string", + typeName: StatusTypeNameError, + serialized: []byte(""), + expectedErrMsg: "", + }, { + name: "invalid status proto", + typeName: StatusTypeNameStatus, + serialized: []byte("not a valid proto"), + expectParseErr: true, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + resultErr, parseErr := parseErrorReplacement( + tc.typeName, tc.serialized, + ) + + if tc.expectParseErr { + require.Error(tt, parseErr) + return + } + + require.NoError(tt, parseErr) + require.Equal(tt, tc.expectedErrMsg, resultErr.Error()) + }) + } +} + +// TestParseErrorReplacementWithStatus tests that parseErrorReplacement +// correctly handles gRPC status errors with proper error codes. +func TestParseErrorReplacementWithStatus(t *testing.T) { + // Create a gRPC status error with a specific code and message. + st := status.New(codes.NotFound, "resource not found") + statusProto := st.Proto() + + // Serialize the status proto. + serialized, err := proto.Marshal(statusProto) + require.NoError(t, err) + + // Parse it back. + resultErr, parseErr := parseErrorReplacement( + StatusTypeNameStatus, serialized, + ) + require.NoError(t, parseErr) + require.Error(t, resultErr) + + // Verify we can extract the status back. + resultStatus, ok := status.FromError(resultErr) + require.True(t, ok) + require.Equal(t, codes.NotFound, resultStatus.Code()) + require.Equal(t, "resource not found", resultStatus.Message()) +} + +// TestNewMessageInterceptionRequestWithStatusError tests that +// NewMessageInterceptionRequest correctly serializes gRPC status errors +// as google.rpc.Status protos instead of plain error strings. +func TestNewMessageInterceptionRequestWithStatusError(t *testing.T) { + testCases := []struct { + name string + err error + expectedTypeName string + isStatusError bool + }{{ + name: "plain error", + err: errors.New("this is a plain error"), + expectedTypeName: StatusTypeNameError, + isStatusError: false, + }, { + name: "gRPC status error", + err: status.Error(codes.NotFound, "resource not found"), + expectedTypeName: StatusTypeNameStatus, + isStatusError: true, + }, { + name: "gRPC status error with different code", + err: status.Error(codes.PermissionDenied, "access denied"), + expectedTypeName: StatusTypeNameStatus, + isStatusError: true, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + ctx := context.Background() + req, err := NewMessageInterceptionRequest( + ctx, TypeResponse, false, "/test/Method", + tc.err, + ) + require.NoError(tt, err) + require.True(tt, req.IsError) + require.Equal(tt, tc.expectedTypeName, req.ProtoTypeName) + + if tc.isStatusError { + // Verify we can parse the status back. + resultErr, parseErr := parseErrorReplacement( + req.ProtoTypeName, req.ProtoSerialized, + ) + require.NoError(tt, parseErr) + + // Verify the error code is preserved. + resultStatus, ok := status.FromError(resultErr) + require.True(tt, ok) + + originalStatus, _ := status.FromError(tc.err) + require.Equal( + tt, originalStatus.Code(), + resultStatus.Code(), + ) + require.Equal( + tt, originalStatus.Message(), + resultStatus.Message(), + ) + } + }) + } +}