From 704b015083c373b9890fb2f1f7f2d117fafb7641 Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sun, 14 Apr 2024 10:38:10 +0800 Subject: [PATCH] Adds possibility to define origin patterns --- channel/channel.go | 28 +++++++++++++++++++++++++++- channel/channel_test.go | 29 +++++++++++++++++++++++++++++ examples/http_backend/main.go | 2 +- 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/channel/channel.go b/channel/channel.go index aeda62d..4bc5510 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -15,8 +15,15 @@ type DefaultChannel struct { connRegistry *DefaultConnectionRegistry ctx context.Context middlewares []Middlewere + config channelConfig } +type channelConfig struct { + originPatterns []string +} + +type Option func(*channelConfig) + // NewDefaultChannel creates new instance of DefaultChannel // path - channel path // dispatcher - dispatcher to use @@ -26,12 +33,22 @@ type DefaultChannel struct { func NewDefaultChannel( path string, dispatcher wasabi.Dispatcher, + opts ...Option, ) *DefaultChannel { + config := channelConfig{ + originPatterns: []string{"*"}, + } + + for _, opt := range opts { + opt(&config) + } + return &DefaultChannel{ path: path, disptacher: dispatcher, connRegistry: NewDefaultConnectionRegistry(), middlewares: make([]Middlewere, 0), + config: config, } } @@ -51,7 +68,7 @@ func (c *DefaultChannel) wsConnectionHandler() http.Handler { ctx := r.Context() ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - OriginPatterns: []string{"*"}, + OriginPatterns: c.config.originPatterns, }) if err != nil { @@ -88,3 +105,12 @@ func (c *DefaultChannel) setContext(next http.Handler) http.Handler { next.ServeHTTP(w, r.WithContext(c.ctx)) }) } + +// WithOriginPatterns sets the origin patterns for the channel. +// The origin patterns are used to validate the Origin header of the WebSocket handshake request. +// If the Origin header does not match any of the patterns, the connection is rejected. +func WithOriginPatterns(patterns ...string) Option { + return func(c *channelConfig) { + c.originPatterns = patterns + } +} diff --git a/channel/channel_test.go b/channel/channel_test.go index e87df37..8733910 100644 --- a/channel/channel_test.go +++ b/channel/channel_test.go @@ -151,3 +151,32 @@ func TestDefaultChannel_SetContextMiddleware(t *testing.T) { t.Errorf("Unexpected context: got %v, expected %v", ctx, channel.ctx) } } + +func TestDefaultChannel_WithOriginPatterns(t *testing.T) { + path := "/test/path" + dispatcher := mocks.NewMockDispatcher(t) + + channel := NewDefaultChannel(path, dispatcher) + + if len(channel.config.originPatterns) != 1 { + t.Errorf("Unexpected number of origin patterns: got %d, expected %d", len(channel.config.originPatterns), 1) + } + + if channel.config.originPatterns[0] != "*" { + t.Errorf("Unexpected to get default origin pattern: got %s, expected %s", channel.config.originPatterns[0], "*") + } + + channel = NewDefaultChannel(path, dispatcher, WithOriginPatterns("test", "test2")) + + if len(channel.config.originPatterns) != 2 { + t.Errorf("Unexpected number of origin patterns: got %d, expected %d", len(channel.config.originPatterns), 1) + } + + if channel.config.originPatterns[0] != "test" { + t.Errorf("Unexpected to get default origin pattern: got %s, expected %s", channel.config.originPatterns[0], "test") + } + + if channel.config.originPatterns[1] != "test2" { + t.Errorf("Unexpected to get default origin pattern: got %s, expected %s", channel.config.originPatterns[1], "test2") + } +} diff --git a/examples/http_backend/main.go b/examples/http_backend/main.go index f4f5bd9..cb7d497 100644 --- a/examples/http_backend/main.go +++ b/examples/http_backend/main.go @@ -40,7 +40,7 @@ func main() { dispatcher.Use(ErrHandler) dispatcher.Use(request.NewTrottlerMiddleware(10)) - channel := channel.NewDefaultChannel("/", dispatcher) + channel := channel.NewDefaultChannel("/", dispatcher, channel.WithOriginPatterns("*")) server := server.NewServer(Port) server.AddChannel(channel)