diff --git a/.gitignore b/.gitignore index 22e98cf..b27a1b0 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,7 @@ .idea # vscode .vscode + +# Go workspace file +go.work +go.work.sum diff --git a/client.go b/client.go new file mode 100644 index 0000000..c26ec71 --- /dev/null +++ b/client.go @@ -0,0 +1,90 @@ +package websocket + +import ( + "bytes" + "errors" + "fmt" + "time" + + "github.com/cloudwego/hertz/pkg/protocol" +) + +// ErrBadHandshake is returned when the server response to opening handshake is +// invalid. +var ErrBadHandshake = errors.New("websocket: bad handshake") + +// ClientUpgrader is a helper for upgrading hertz http response to websocket conn. +// See ExampleClient for usage +type ClientUpgrader struct { + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool +} + +// PrepareRequest prepares request for websocket +// +// It adds headers for websocket, +// and it must be called BEFORE sending http request via cli.DoXXX +func (p *ClientUpgrader) PrepareRequest(req *protocol.Request) { + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", generateChallengeKey()) + if p.EnableCompression { + req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } +} + +// UpgradeResponse upgrades a response to websocket conn +// +// It returns Conn if success. ErrBadHandshake is returned if headers go wrong. +// This method must be called after PrepareRequest and (*.Client).DoXXX +func (p *ClientUpgrader) UpgradeResponse(req *protocol.Request, resp *protocol.Response) (*Conn, error) { + if resp.StatusCode() != 101 || + !tokenContainsValue(resp.Header.Get("Upgrade"), "websocket") || + !tokenContainsValue(resp.Header.Get("Connection"), "Upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKeyBytes(req.Header.Peek("Sec-Websocket-Key")) { + return nil, ErrBadHandshake + } + + c, err := resp.Hijack() + if err != nil { + return nil, fmt.Errorf("Hijack response connection err: %w", err) + } + + c.SetDeadline(time.Time{}) + conn := newConn(c, false, p.ReadBufferSize, p.WriteBufferSize, p.WriteBufferPool, nil, nil) + + // can not use p.EnableCompression, always follow ext returned from server + compress := false + extensions := parseDataHeader(resp.Header.Peek("Sec-WebSocket-Extensions")) + for _, ext := range extensions { + if bytes.HasPrefix(ext, strPermessageDeflate) { + compress = true + } + } + if compress { + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + } + return conn, nil +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..26f5f93 --- /dev/null +++ b/client_test.go @@ -0,0 +1,85 @@ +package websocket + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/network/standard" + "github.com/cloudwego/hertz/pkg/protocol" +) + +const ( + testaddr = "localhost:10012" + testpath = "/echo" +) + +func ExampleClient() { + runServer(testaddr) + time.Sleep(50 * time.Millisecond) // await server running + + c, err := client.NewClient(client.WithDialer(standard.NewDialer())) + if err != nil { + panic(err) + } + + req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() + req.SetRequestURI("http://" + testaddr + testpath) + req.SetMethod("GET") + + u := &ClientUpgrader{} + u.PrepareRequest(req) + err = c.Do(context.Background(), req, resp) + if err != nil { + panic(err) + } + conn, err := u.UpgradeResponse(req, resp) + if err != nil { + panic(err) + } + + conn.WriteMessage(TextMessage, []byte("hello")) + m, b, err := conn.ReadMessage() + if err != nil { + panic(err) + } + fmt.Println(m, string(b)) + // Output: 1 hello +} + +func runServer(addr string) { + upgrader := HertzUpgrader{} // use default options + h := server.Default(server.WithHostPorts(addr)) + // https://github.com/cloudwego/hertz/issues/121 + h.NoHijackConnPool = true + h.GET(testpath, func(_ context.Context, c *app.RequestContext) { + err := upgrader.Upgrade(c, func(conn *Conn) { + for { + mt, message, err := conn.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + log.Printf("[server] recv: %v %s", mt, message) + err = conn.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + break + } + } + }) + if err != nil { + log.Print("upgrade:", err) + return + } + }) + go func() { + if err := h.Run(); err != nil { + log.Fatal(err) + } + }() +} diff --git a/go.mod b/go.mod index e8f2f89..0cb2209 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/hertz-contrib/websocket go 1.16 require ( - github.com/bytedance/sonic v1.11.9 - github.com/cloudwego/hertz v0.9.1 + github.com/bytedance/sonic v1.12.0 + github.com/cloudwego/hertz v0.9.4-0.20241021100040-3477b0309b81 ) diff --git a/go.sum b/go.sum index e63d1e0..c674fab 100644 --- a/go.sum +++ b/go.sum @@ -1,25 +1,23 @@ github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6HaZIxD39I= github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= -github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= -github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= -github.com/bytedance/mockey v1.2.1 h1:g84ngI88hz1DR4wZTL3yOuqlEcq67MretBfQUdXwrmw= -github.com/bytedance/mockey v1.2.1/go.mod h1:+Jm/fzWZAuhEDrPXVjDf/jLM2BlLXJkwk94zf2JZ3X4= -github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/bytedance/sonic v1.11.9 h1:LFHENlIY/SLzDWverzdOvgMztTxcfcF+cqNsz9pK5zg= -github.com/bytedance/sonic v1.11.9/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= +github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/mockey v1.2.12 h1:aeszOmGw8CPX8CRx1DZ/Glzb1yXvhjDh6jdFBNZjsU4= +github.com/bytedance/mockey v1.2.12/go.mod h1:3ZA4MQasmqC87Tw0w7Ygdy7eHIc2xgpZ8Pona5rsYIk= +github.com/bytedance/sonic v1.12.0 h1:YGPgxF9xzaCNvd/ZKdQ28yRovhfMFZQjuk6fKBzZ3ls= +github.com/bytedance/sonic v1.12.0/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= +github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/hertz v0.9.1 h1:+jK9A6MDNTUVy6q/zSOlhbnp1fFMiOaPIsq0jlOfjZE= -github.com/cloudwego/hertz v0.9.1/go.mod h1:cs8dH6unM4oaJ5k9m6pqbgLBPqakGWMG0+cthsxitsg= +github.com/cloudwego/hertz v0.9.4-0.20241021100040-3477b0309b81 h1:lrZ2nuRsR4M9KG1N+ihkict9Q2gzNwFxmId6NksKCAY= +github.com/cloudwego/hertz v0.9.4-0.20241021100040-3477b0309b81/go.mod h1:gGVUfJU/BOkJv/ZTzrw7FS7uy7171JeYIZvAyV3wS3o= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= -github.com/cloudwego/netpoll v0.6.0 h1:JRMkrA1o8k/4quxzg6Q1XM+zIhwZsyoWlq6ef+ht31U= -github.com/cloudwego/netpoll v0.6.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.6.2 h1:+KdILv5ATJU+222wNNXpHapYaBeRvvL8qhJyhcxRxrQ= +github.com/cloudwego/netpoll v0.6.2/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzLhCrTrz3HM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -74,12 +72,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VA golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/util.go b/util.go index 9993d98..035aa43 100644 --- a/util.go +++ b/util.go @@ -11,12 +11,21 @@ import ( "bytes" "crypto/sha1" "encoding/base64" + "encoding/binary" + "math/rand" "unicode/utf8" "unsafe" ) var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") +func generateChallengeKey() string { + b := make([]byte, 16) + binary.BigEndian.PutUint64(b, rand.Uint64()) + binary.BigEndian.PutUint64(b[8:], rand.Uint64()) + return base64.StdEncoding.EncodeToString(b) +} + // Token octets per RFC 2616. var isTokenOctet = [256]bool{ '!': true,