diff --git a/go.mod b/go.mod index 34ed75164..f523e8465 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/FZambia/statik v0.1.2-0.20180217151304-b9f012bb2a1b github.com/FZambia/tarantool v0.3.1 github.com/FZambia/viper-lite v0.0.0-20220110144934-1899f66c7d0e - github.com/centrifugal/centrifuge v0.33.4 + github.com/centrifugal/centrifuge v0.33.5-0.20241026095149-106d43f21426 github.com/centrifugal/protocol v0.13.4 github.com/cristalhq/jwt/v5 v5.4.0 github.com/gobwas/glob v0.2.3 @@ -20,7 +20,7 @@ require ( github.com/mattn/go-isatty v0.0.20 github.com/mitchellh/mapstructure v1.5.0 github.com/nats-io/nats.go v1.37.0 - github.com/prometheus/client_golang v1.20.4 + github.com/prometheus/client_golang v1.20.5 github.com/quic-go/quic-go v0.47.0 github.com/quic-go/webtransport-go v0.8.0 github.com/rakutentech/jwk-go v1.1.3 @@ -92,7 +92,7 @@ require ( github.com/prometheus/common v0.60.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect - github.com/redis/rueidis v1.0.47 // indirect + github.com/redis/rueidis v1.0.48 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/segmentio/encoding v0.4.0 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/go.sum b/go.sum index 18ad310ac..4a22ded1b 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/centrifugal/centrifuge v0.33.4 h1:h0b1X5DdKPkizJ1dJEo5u83a5baxyUHpDLhfSODHqbY= -github.com/centrifugal/centrifuge v0.33.4/go.mod h1:83bjiDCVcoWrXXjFibG9vS0fJ+aUpSA3kD2RLlvl1RE= +github.com/centrifugal/centrifuge v0.33.5-0.20241026095149-106d43f21426 h1:g5zZaCr/BybYgq8Nqrnrvqvb3jGGO/Dloil3cFGzzbg= +github.com/centrifugal/centrifuge v0.33.5-0.20241026095149-106d43f21426/go.mod h1:Ck+7H3eVwoeyabKcj3L55oSunaORIOGPAIVB5xrQyGU= github.com/centrifugal/protocol v0.13.4 h1:I0YxXtFNfn/ndDIZp5RkkqQcSSNH7DNPUbXKYtJXDzs= github.com/centrifugal/protocol v0.13.4/go.mod h1:7V5vI30VcoxJe4UD87xi7bOsvI0bmEhvbQuMjrFM2L4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -144,8 +144,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= -github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= -github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= +github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.60.0 h1:+V9PAREWNvJMAuJ1x1BaWl9dewMW4YrHZQbx0sJNllA= @@ -160,8 +160,8 @@ github.com/quic-go/webtransport-go v0.8.0 h1:HxSrwun11U+LlmwpgM1kEqIqH90IT4N8auv github.com/quic-go/webtransport-go v0.8.0/go.mod h1:N99tjprW432Ut5ONql/aUhSLT0YVSlwHohQsuac9WaM= github.com/rakutentech/jwk-go v1.1.3 h1:PiLwepKyUaW+QFG3ki78DIO2+b4IVK3nMhlxM70zrQ4= github.com/rakutentech/jwk-go v1.1.3/go.mod h1:LtzSv4/+Iti1nnNeVQiP6l5cI74GBStbhyXCYvgPZFk= -github.com/redis/rueidis v1.0.47 h1:41UdeXOo4eJuW+cfpUJuLtVGyO0QJY3A2rEYgJWlfHs= -github.com/redis/rueidis v1.0.47/go.mod h1:by+34b0cFXndxtYmPAHpoTHO5NkosDlBvhexoTURIxM= +github.com/redis/rueidis v1.0.48 h1:ggZHjEtc/echUmPkGTfssRisnc3p/mIUEwrpbNsZ1mQ= +github.com/redis/rueidis v1.0.48/go.mod h1:by+34b0cFXndxtYmPAHpoTHO5NkosDlBvhexoTURIxM= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= @@ -245,8 +245,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= @@ -283,12 +281,8 @@ golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -304,8 +298,6 @@ google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 h1: google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:qpvKtACPCQhAdu3PyQgV4l3LMXZEtft7y8QcarRsp9I= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= -google.golang.org/grpc v1.66.1 h1:hO5qAXR19+/Z44hmvIM4dQFMSYX9XcWsByfoxutBpAM= -google.golang.org/grpc v1.66.1/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/internal/proxy/connect_handler.go b/internal/proxy/connect_handler.go index e92ce44e5..b80791190 100644 --- a/internal/proxy/connect_handler.go +++ b/internal/proxy/connect_handler.go @@ -77,7 +77,7 @@ func (h *ConnectHandler) Handle(node *centrifuge.Node) ConnectingHandlerFunc { h.histogram.Observe(duration) h.errors.Inc() node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "error proxying connect", map[string]any{"client": e.ClientID, "error": err.Error()})) - return centrifuge.ConnectReply{}, ConnectExtra{}, centrifuge.ErrorInternal + return centrifuge.ConnectReply{}, ConnectExtra{}, err } h.summary.Observe(duration) h.histogram.Observe(duration) diff --git a/internal/proxy/connect_handler_test.go b/internal/proxy/connect_handler_test.go index 4a92ac59b..4a21800e7 100644 --- a/internal/proxy/connect_handler_test.go +++ b/internal/proxy/connect_handler_test.go @@ -15,6 +15,8 @@ import ( "github.com/centrifugal/centrifuge" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type grpcConnHandleTestCase struct { @@ -41,6 +43,9 @@ func getTestHttpProxy(commonProxyTestCase *tools.CommonHTTPProxyTestCase, endpoi StaticHttpHeaders: map[string]string{ "X-Test": "test", }, + HttpStatusTransforms: []HttpStatusToCodeTransform{ + {StatusCode: 404, ToDisconnect: TransformDisconnect{Code: 4504, Reason: "not found"}}, + }, } } @@ -227,7 +232,13 @@ func TestHandleConnectWithoutProxyServerStart(t *testing.T) { cases := newConnHandleTestCases(httpTestCase, grpcTestCase) for _, c := range cases { reply, err := c.invokeHandle(context.Background()) - require.ErrorIs(t, centrifuge.ErrorInternal, err, c.protocol) + if c.protocol == "grpc" { + st, ok := status.FromError(err) + require.True(t, ok, c.protocol) + require.Equal(t, codes.Unavailable, st.Code(), c.protocol) + } else { + require.Error(t, err, c.protocol) + } require.Equal(t, centrifuge.ConnectReply{}, reply, c.protocol) } } @@ -326,3 +337,32 @@ func TestHandleConnectWithSubscriptionError(t *testing.T) { require.Equal(t, centrifuge.ConnectReply{}, reply, c.protocol) } } + +func TestHandleConnectWithHTTPCodeTransform(t *testing.T) { + grpcTestCase := newConnHandleGRPCTestCase(context.Background(), newProxyGRPCTestServer("http status code transform", proxyGRPCTestServerOptions{})) + defer grpcTestCase.Teardown() + + httpTestCase := newConnHandleHTTPTestCase(context.Background(), "/proxy") + httpTestCase.Mux.HandleFunc("/proxy", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{}`)) + }) + defer httpTestCase.Teardown() + + cases := newConnHandleTestCases(httpTestCase, grpcTestCase) + for _, c := range cases { + if c.protocol == "grpc" { + continue // Transforms not supported. + } + + expectedErr := centrifuge.Disconnect{ + Code: 4504, + Reason: "not found", + } + + reply, err := c.invokeHandle(context.Background()) + require.NotNil(t, err, c.protocol) + require.Equal(t, expectedErr.Error(), err.Error(), c.protocol) + require.Equal(t, centrifuge.ConnectReply{}, reply, c.protocol) + } +} diff --git a/internal/proxy/connect_http.go b/internal/proxy/connect_http.go index 47ecbb56b..37b612766 100644 --- a/internal/proxy/connect_http.go +++ b/internal/proxy/connect_http.go @@ -41,6 +41,13 @@ func (p *HTTPConnectProxy) ProxyConnect(ctx context.Context, req *proxyproto.Con } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.ConnectResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeConnectResponse(respData) diff --git a/internal/proxy/http.go b/internal/proxy/http.go index 6cefb8458..557bcfbfb 100644 --- a/internal/proxy/http.go +++ b/internal/proxy/http.go @@ -3,6 +3,7 @@ package proxy import ( "bytes" "context" + "errors" "fmt" "io" "net/http" @@ -107,3 +108,31 @@ func stringInSlice(a string, list []string) bool { } return false } + +func transformHTTPStatusError(err error, transforms []HttpStatusToCodeTransform) (*proxyproto.Error, *proxyproto.Disconnect) { + if len(transforms) == 0 { + return nil, nil + } + var statusErr *statusCodeError + if !errors.As(err, &statusErr) { + return nil, nil + } + for _, t := range transforms { + if t.StatusCode == statusErr.Code { + if t.ToError.Code > 0 { + return &proxyproto.Error{ + Code: t.ToError.Code, + Message: t.ToError.Message, + Temporary: t.ToError.Temporary, + }, nil + } + if t.ToDisconnect.Code > 0 { + return nil, &proxyproto.Disconnect{ + Code: t.ToDisconnect.Code, + Reason: t.ToDisconnect.Reason, + } + } + } + } + return nil, nil +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 73935b7b5..d88b57a9d 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -9,6 +9,23 @@ import ( "github.com/centrifugal/centrifugo/v5/internal/tools" ) +type TransformError struct { + Code uint32 `mapstructure:"code" json:"code"` + Message string `mapstructure:"message" json:"message"` + Temporary bool `mapstructure:"temporary" json:"temporary"` +} + +type TransformDisconnect struct { + Code uint32 `mapstructure:"code" json:"code"` + Reason string `mapstructure:"reason" json:"reason"` +} + +type HttpStatusToCodeTransform struct { + StatusCode int `mapstructure:"status_code" json:"status_code"` + ToError TransformError `mapstructure:"to_error" json:"to_error"` + ToDisconnect TransformDisconnect `mapstructure:"to_disconnect" json:"to_disconnect"` +} + // Config for proxy. type Config struct { // Name is a unique name of proxy to reference. @@ -20,7 +37,9 @@ type Config struct { // HTTPHeaders is a list of HTTP headers to proxy. No headers used by proxy by default. // If GRPC proxy is used then request HTTP headers set to outgoing request metadata. - HttpHeaders []string `mapstructure:"http_headers" json:"http_headers,omitempty"` + HttpHeaders []string `mapstructure:"http_headers" json:"http_headers,omitempty"` + HttpStatusTransforms []HttpStatusToCodeTransform `mapstructure:"http_status_to_code_transforms" json:"http_status_to_code_transforms,omitempty"` + // GRPCMetadata is a list of GRPC metadata keys to proxy. No meta keys used by proxy by // default. If HTTP proxy is used then these keys become outgoing request HTTP headers. GrpcMetadata []string `mapstructure:"grpc_metadata" json:"grpc_metadata,omitempty"` diff --git a/internal/proxy/publish_http.go b/internal/proxy/publish_http.go index 0a27e6ced..4186f09ac 100644 --- a/internal/proxy/publish_http.go +++ b/internal/proxy/publish_http.go @@ -44,6 +44,13 @@ func (p *HTTPPublishProxy) ProxyPublish(ctx context.Context, req *proxyproto.Pub } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.PublishResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodePublishResponse(respData) diff --git a/internal/proxy/refresh_http.go b/internal/proxy/refresh_http.go index b5741e672..c3621692d 100644 --- a/internal/proxy/refresh_http.go +++ b/internal/proxy/refresh_http.go @@ -40,6 +40,13 @@ func (p *HTTPRefreshProxy) ProxyRefresh(ctx context.Context, req *proxyproto.Ref if err != nil { return nil, err } + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.RefreshResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return httpDecoder.DecodeRefreshResponse(respData) } diff --git a/internal/proxy/rpc_http.go b/internal/proxy/rpc_http.go index d78a37ed5..4303f7823 100644 --- a/internal/proxy/rpc_http.go +++ b/internal/proxy/rpc_http.go @@ -31,6 +31,13 @@ func (p *HTTPRPCProxy) ProxyRPC(ctx context.Context, req *proxyproto.RPCRequest) } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.RPCResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeRPCResponse(respData) diff --git a/internal/proxy/sub_refresh_http.go b/internal/proxy/sub_refresh_http.go index 4de1640c0..4d7fcf9cf 100644 --- a/internal/proxy/sub_refresh_http.go +++ b/internal/proxy/sub_refresh_http.go @@ -39,6 +39,13 @@ func (p *HTTPSubRefreshProxy) ProxySubRefresh(ctx context.Context, req *proxypro } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.SubRefreshResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeSubRefreshResponse(respData) diff --git a/internal/proxy/subscribe_http.go b/internal/proxy/subscribe_http.go index 6cf0276f2..3fa09b771 100644 --- a/internal/proxy/subscribe_http.go +++ b/internal/proxy/subscribe_http.go @@ -31,6 +31,13 @@ func (p *HTTPSubscribeProxy) ProxySubscribe(ctx context.Context, req *proxyproto } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.SubscribeResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeSubscribeResponse(respData) diff --git a/internal/tools/ascii.go b/internal/tools/ascii.go new file mode 100644 index 000000000..4944411de --- /dev/null +++ b/internal/tools/ascii.go @@ -0,0 +1,12 @@ +package tools + +import "unicode" + +func IsASCII(s string) bool { + for _, c := range s { + if c > unicode.MaxASCII { + return false + } + } + return true +} diff --git a/internal/tools/code_translate.go b/internal/tools/code_translate.go new file mode 100644 index 000000000..138db4cd6 --- /dev/null +++ b/internal/tools/code_translate.go @@ -0,0 +1,54 @@ +package tools + +import ( + "net/http" + + "github.com/centrifugal/centrifuge" +) + +type ConnectCodeToHTTPStatus struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Transforms []ConnectCodeToHTTPStatusTransform `mapstructure:"transforms" json:"transforms"` +} + +type ConnectCodeToHTTPStatusTransform struct { + Code uint32 `mapstructure:"code" json:"code"` + ToResponse TransformedConnectErrorHttpResponse `mapstructure:"to_response" json:"to_response"` +} + +type TransformedConnectErrorHttpResponse struct { + Status int `mapstructure:"status_code" json:"status_code"` + Body string `mapstructure:"body" json:"body"` +} + +func ConnectErrorToToHTTPResponse(err error, transforms []ConnectCodeToHTTPStatusTransform) (TransformedConnectErrorHttpResponse, bool) { + var code uint32 + var body string + switch t := err.(type) { + case *centrifuge.Disconnect: + code = t.Code + body = t.Reason + case centrifuge.Disconnect: + code = t.Code + body = t.Reason + case *centrifuge.Error: + code = t.Code + body = t.Message + default: + } + if code > 0 { + for _, t := range transforms { + if t.Code != code { + continue + } + if t.ToResponse.Body == "" { + t.ToResponse.Body = body + } + return t.ToResponse, true + } + } + return TransformedConnectErrorHttpResponse{ + Status: http.StatusInternalServerError, + Body: http.StatusText(http.StatusInternalServerError), + }, false +} diff --git a/internal/unihttpstream/config.go b/internal/unihttpstream/config.go index 49528c03b..30ff3bccb 100644 --- a/internal/unihttpstream/config.go +++ b/internal/unihttpstream/config.go @@ -1,10 +1,14 @@ package unihttpstream -import "github.com/centrifugal/centrifuge" +import ( + "github.com/centrifugal/centrifugo/v5/internal/tools" + + "github.com/centrifugal/centrifuge" +) type Config struct { // MaxRequestBodySize limits request body size. - MaxRequestBodySize int - + MaxRequestBodySize int + ConnectCodeToHTTPStatus tools.ConnectCodeToHTTPStatus centrifuge.PingPongConfig } diff --git a/internal/unihttpstream/handler.go b/internal/unihttpstream/handler.go index 66b72f183..bcd246565 100644 --- a/internal/unihttpstream/handler.go +++ b/internal/unihttpstream/handler.go @@ -6,6 +6,8 @@ import ( "net/http" "time" + "github.com/centrifugal/centrifugo/v5/internal/tools" + "github.com/centrifugal/centrifuge" "github.com/centrifugal/protocol" ) @@ -71,24 +73,6 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }(time.Now()) } - if r.ProtoMajor == 1 { - // An endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields. - // Source: RFC7540. - w.Header().Set("Connection", "keep-alive") - } - w.Header().Set("X-Accel-Buffering", "no") - w.Header().Set("Cache-Control", "private, no-cache, no-store, must-revalidate, max-age=0") - w.Header().Set("Pragma", "no-cache") - w.Header().Set("Expire", "0") - w.WriteHeader(http.StatusOK) - - _, ok := w.(http.Flusher) - if !ok { - return - } - - rc := http.NewResponseController(w) - connectRequest := centrifuge.ConnectRequest{ Token: req.Token, Data: req.Data, @@ -107,7 +91,40 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { connectRequest.Subs = subs } - c.Connect(connectRequest) + if h.config.ConnectCodeToHTTPStatus.Enabled { + err = c.ConnectNoErrorToDisconnect(connectRequest) + if err != nil { + resp, ok := tools.ConnectErrorToToHTTPResponse(err, h.config.ConnectCodeToHTTPStatus.Transforms) + if ok { + w.WriteHeader(resp.Status) + _, _ = w.Write([]byte(resp.Body)) + return + } + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(http.StatusText(http.StatusInternalServerError))) + return + } + } else { + c.Connect(connectRequest) + } + + if r.ProtoMajor == 1 { + // An endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields. + // Source: RFC7540. + w.Header().Set("Connection", "keep-alive") + } + w.Header().Set("X-Accel-Buffering", "no") + w.Header().Set("Cache-Control", "private, no-cache, no-store, must-revalidate, max-age=0") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Expire", "0") + w.WriteHeader(http.StatusOK) + + _, ok := w.(http.Flusher) + if !ok { + return + } + + rc := http.NewResponseController(w) for { select { diff --git a/internal/unisse/config.go b/internal/unisse/config.go index d7478fba6..19184a2fc 100644 --- a/internal/unisse/config.go +++ b/internal/unisse/config.go @@ -1,10 +1,13 @@ package unisse -import "github.com/centrifugal/centrifuge" +import ( + "github.com/centrifugal/centrifuge" + "github.com/centrifugal/centrifugo/v5/internal/tools" +) type Config struct { // MaxRequestBodySize for POST requests when used. - MaxRequestBodySize int - + MaxRequestBodySize int + ConnectCodeToHTTPStatus tools.ConnectCodeToHTTPStatus centrifuge.PingPongConfig } diff --git a/internal/unisse/handler.go b/internal/unisse/handler.go index 27519c697..803e1d693 100644 --- a/internal/unisse/handler.go +++ b/internal/unisse/handler.go @@ -6,6 +6,8 @@ import ( "net/http" "time" + "github.com/centrifugal/centrifugo/v5/internal/tools" + "github.com/centrifugal/centrifuge" "github.com/centrifugal/protocol" ) @@ -81,6 +83,41 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }(time.Now()) } + connectRequest := centrifuge.ConnectRequest{ + Token: req.Token, + Data: req.Data, + Name: req.Name, + Version: req.Version, + } + if req.Subs != nil { + subs := make(map[string]centrifuge.SubscribeRequest, len(req.Subs)) + for k, v := range req.Subs { + subs[k] = centrifuge.SubscribeRequest{ + Recover: v.Recover, + Offset: v.Offset, + Epoch: v.Epoch, + } + } + connectRequest.Subs = subs + } + + if h.config.ConnectCodeToHTTPStatus.Enabled { + err = c.ConnectNoErrorToDisconnect(connectRequest) + if err != nil { + resp, ok := tools.ConnectErrorToToHTTPResponse(err, h.config.ConnectCodeToHTTPStatus.Transforms) + if ok { + w.WriteHeader(resp.Status) + _, _ = w.Write([]byte(resp.Body)) + return + } + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(http.StatusText(http.StatusInternalServerError))) + return + } + } else { + c.Connect(connectRequest) + } + if r.ProtoMajor == 1 { // An endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields. // Source: RFC7540. @@ -106,26 +143,6 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } _ = rc.Flush() - connectRequest := centrifuge.ConnectRequest{ - Token: req.Token, - Data: req.Data, - Name: req.Name, - Version: req.Version, - } - if req.Subs != nil { - subs := make(map[string]centrifuge.SubscribeRequest, len(req.Subs)) - for k, v := range req.Subs { - subs[k] = centrifuge.SubscribeRequest{ - Recover: v.Recover, - Offset: v.Offset, - Epoch: v.Epoch, - } - } - connectRequest.Subs = subs - } - - c.Connect(connectRequest) - for { select { case <-r.Context().Done(): diff --git a/main.go b/main.go index 21a573e36..6fd79e80a 100644 --- a/main.go +++ b/main.go @@ -243,6 +243,11 @@ var defaults = map[string]any{ "uni_sse": false, "uni_http_stream": false, + "uni_sse_connect_code_to_http_response.enabled": false, + "uni_sse_connect_code_to_http_response.transforms": []any{}, + "uni_http_stream_connect_code_to_http_response.enabled": false, + "uni_http_stream_connect_code_to_http_response.transforms": []any{}, + "log_level": "info", "log_file": "", @@ -291,14 +296,15 @@ var defaults = map[string]any{ "proxy_sub_refresh_timeout": time.Second, "proxy_subscribe_stream_timeout": time.Second, - "proxy_grpc_metadata": []string{}, - "proxy_http_headers": []string{}, - "proxy_static_http_headers": map[string]string{}, - "proxy_binary_encoding": false, - "proxy_include_connection_meta": false, - "proxy_grpc_cert_file": "", - "proxy_grpc_compression": false, - "proxy_grpc_tls": tools.TLSConfig{}, + "proxy_http_status_code_transforms": []any{}, + "proxy_grpc_metadata": []string{}, + "proxy_http_headers": []string{}, + "proxy_static_http_headers": map[string]string{}, + "proxy_binary_encoding": false, + "proxy_include_connection_meta": false, + "proxy_grpc_cert_file": "", + "proxy_grpc_compression": false, + "proxy_grpc_tls": tools.TLSConfig{}, "tarantool_mode": "standalone", "tarantool_address": "tcp://127.0.0.1:3301", @@ -2065,6 +2071,36 @@ func proxyMapConfig() (*client.ProxyMap, bool) { } proxyConfig.StaticHttpHeaders = staticHttpHeaders + var httpStatusTransforms []proxy.HttpStatusToCodeTransform + if v.IsSet("proxy_http_status_code_transforms") { + tools.DecodeSlice(v, &httpStatusTransforms, "proxy_http_status_code_transforms") + } + for _, transform := range httpStatusTransforms { + if transform.StatusCode == 0 { + log.Fatal().Msg("status should be set in proxy_http_status_code_transforms item") + } + if transform.ToDisconnect.Code == 0 && transform.ToError.Code == 0 { + log.Fatal().Msg("no error or disconnect code set in proxy_http_status_code_transforms item") + } + if transform.ToDisconnect.Code > 0 && transform.ToError.Code > 0 { + log.Fatal().Msg("only error or disconnect code can be set in proxy_http_status_code_transforms item, but not both") + } + if !tools.IsASCII(transform.ToDisconnect.Reason) { + log.Fatal().Msg("proxy_http_status_code_transforms item disconnect reason must be ASCII") + } + if !tools.IsASCII(transform.ToError.Message) { + log.Fatal().Msg("proxy_http_status_code_transforms item error message must be ASCII") + } + const reasonOrMessageMaxLength = 123 // limit comes from WebSocket close reason length limit. See https://datatracker.ietf.org/doc/html/rfc6455. + if len(transform.ToError.Message) > reasonOrMessageMaxLength { + log.Fatal().Msgf("proxy_http_status_code_transforms item error message can be up to %d characters long", reasonOrMessageMaxLength) + } + if len(transform.ToDisconnect.Reason) > reasonOrMessageMaxLength { + log.Fatal().Msgf("proxy_http_status_code_transforms item disconnect reason can be up to %d characters long", reasonOrMessageMaxLength) + } + } + proxyConfig.HttpStatusTransforms = httpStatusTransforms + connectEndpoint := v.GetString("proxy_connect_endpoint") connectTimeout := GetDuration("proxy_connect_timeout") refreshEndpoint := v.GetString("proxy_refresh_endpoint") @@ -2529,16 +2565,34 @@ func uniWebsocketHandlerConfig() uniws.Config { } func uniSSEHandlerConfig() unisse.Config { + connectCodeToHttpStatusEnabled := viper.GetBool("uni_sse_connect_code_to_http_response.enabled") + var connectCodeToHTTPStatusTransforms []tools.ConnectCodeToHTTPStatusTransform + if viper.IsSet("uni_sse_connect_code_to_http_response.transforms") { + tools.DecodeSlice(viper.GetViper(), &connectCodeToHTTPStatusTransforms, "uni_sse_connect_code_to_http_response.transforms") + } return unisse.Config{ MaxRequestBodySize: viper.GetInt("uni_sse_max_request_body_size"), PingPongConfig: getPingPongConfig(), + ConnectCodeToHTTPStatus: tools.ConnectCodeToHTTPStatus{ + Enabled: connectCodeToHttpStatusEnabled, + Transforms: connectCodeToHTTPStatusTransforms, + }, } } func uniStreamHandlerConfig() unihttpstream.Config { + connectCodeToHttpStatusEnabled := viper.GetBool("uni_http_stream_connect_code_to_http_response.enabled") + var connectCodeToHTTPStatusTransforms []tools.ConnectCodeToHTTPStatusTransform + if viper.IsSet("uni_http_stream_connect_code_to_http_response.transforms") { + tools.DecodeSlice(viper.GetViper(), &connectCodeToHTTPStatusTransforms, "uni_http_stream_connect_code_to_http_response.transforms") + } return unihttpstream.Config{ MaxRequestBodySize: viper.GetInt("uni_http_stream_max_request_body_size"), PingPongConfig: getPingPongConfig(), + ConnectCodeToHTTPStatus: tools.ConnectCodeToHTTPStatus{ + Enabled: connectCodeToHttpStatusEnabled, + Transforms: connectCodeToHTTPStatusTransforms, + }, } }