Skip to content

Commit

Permalink
Adds possibility to define origin patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
ksysoev committed Apr 14, 2024
1 parent d1aa1c2 commit 704b015
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
28 changes: 27 additions & 1 deletion channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ type DefaultChannel struct {
connRegistry *DefaultConnectionRegistry
ctx context.Context
middlewares []Middlewere
config channelConfig
}

type channelConfig struct {
originPatterns []string
}

type Option func(*channelConfig)

// NewDefaultChannel creates new instance of DefaultChannel
// path - channel path
// dispatcher - dispatcher to use
Expand All @@ -26,12 +33,22 @@ type DefaultChannel struct {
func NewDefaultChannel(
path string,
dispatcher wasabi.Dispatcher,
opts ...Option,
) *DefaultChannel {
config := channelConfig{
originPatterns: []string{"*"},
}

for _, opt := range opts {
opt(&config)
}

return &DefaultChannel{
path: path,
disptacher: dispatcher,
connRegistry: NewDefaultConnectionRegistry(),
middlewares: make([]Middlewere, 0),
config: config,
}
}

Expand All @@ -51,7 +68,7 @@ func (c *DefaultChannel) wsConnectionHandler() http.Handler {
ctx := r.Context()

ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
OriginPatterns: c.config.originPatterns,
})

if err != nil {
Expand Down Expand Up @@ -88,3 +105,12 @@ func (c *DefaultChannel) setContext(next http.Handler) http.Handler {
next.ServeHTTP(w, r.WithContext(c.ctx))
})
}

// WithOriginPatterns sets the origin patterns for the channel.
// The origin patterns are used to validate the Origin header of the WebSocket handshake request.
// If the Origin header does not match any of the patterns, the connection is rejected.
func WithOriginPatterns(patterns ...string) Option {
return func(c *channelConfig) {
c.originPatterns = patterns
}
}
29 changes: 29 additions & 0 deletions channel/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,32 @@ func TestDefaultChannel_SetContextMiddleware(t *testing.T) {
t.Errorf("Unexpected context: got %v, expected %v", ctx, channel.ctx)
}
}

func TestDefaultChannel_WithOriginPatterns(t *testing.T) {
path := "/test/path"
dispatcher := mocks.NewMockDispatcher(t)

channel := NewDefaultChannel(path, dispatcher)

if len(channel.config.originPatterns) != 1 {
t.Errorf("Unexpected number of origin patterns: got %d, expected %d", len(channel.config.originPatterns), 1)
}

if channel.config.originPatterns[0] != "*" {
t.Errorf("Unexpected to get default origin pattern: got %s, expected %s", channel.config.originPatterns[0], "*")
}

channel = NewDefaultChannel(path, dispatcher, WithOriginPatterns("test", "test2"))

if len(channel.config.originPatterns) != 2 {
t.Errorf("Unexpected number of origin patterns: got %d, expected %d", len(channel.config.originPatterns), 1)
}

if channel.config.originPatterns[0] != "test" {
t.Errorf("Unexpected to get default origin pattern: got %s, expected %s", channel.config.originPatterns[0], "test")
}

if channel.config.originPatterns[1] != "test2" {
t.Errorf("Unexpected to get default origin pattern: got %s, expected %s", channel.config.originPatterns[1], "test2")
}
}
2 changes: 1 addition & 1 deletion examples/http_backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func main() {
dispatcher.Use(ErrHandler)
dispatcher.Use(request.NewTrottlerMiddleware(10))

channel := channel.NewDefaultChannel("/", dispatcher)
channel := channel.NewDefaultChannel("/", dispatcher, channel.WithOriginPatterns("*"))

server := server.NewServer(Port)
server.AddChannel(channel)
Expand Down

0 comments on commit 704b015

Please sign in to comment.