diff --git a/application.go b/application.go index c948b32..f7ec0f6 100644 --- a/application.go +++ b/application.go @@ -5,7 +5,8 @@ import ( "errors" "io" "net" - "runtime" + "net/url" + "strings" "sync" "sync/atomic" "time" @@ -44,8 +45,8 @@ type server struct { } } -// NewApplication returns a net application with listener -func NewApplication(listener net.Listener, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) { +// NewApplicationWithListener returns a net application with listener +func NewApplicationWithListener(listener net.Listener, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) { s := &server{ listener: listener, handleFunc: handleFunc, @@ -70,17 +71,23 @@ func NewApplication(listener net.Listener, handleFunc func(IOSession, interface{ return s, nil } -// NewTCPApplication returns a net application -func NewTCPApplication(addr string, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) { +// NewApplication returns a application +func NewApplication(address string, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) { + network, address, err := parseAdddress(address) + if err != nil { + return nil, err + } + listenConfig := &net.ListenConfig{ Control: listenControl, } - listener, err := listenConfig.Listen(context.TODO(), "tcp4", addr) + + listener, err := listenConfig.Listen(context.TODO(), network, address) if err != nil { return nil, err } - return NewApplication(listener, handleFunc, opts...) + return NewApplicationWithListener(listener, handleFunc, opts...) } func (s *server) Start() error { @@ -183,17 +190,6 @@ func (s *server) doStart() { s.addSession(rs) go func() { - defer func() { - if err := recover(); err != nil { - const size = 64 << 10 - rBuf := make([]byte, size) - rBuf = rBuf[:runtime.Stack(rBuf, false)] - s.opts.logger.Error("connection painc", - zap.Any("err", err), - zap.String("stack", string(rBuf))) - } - }() - defer func() { s.deleteSession(rs) rs.Close() @@ -268,3 +264,20 @@ func (s *server) isStarted() bool { return s.mu.running } + +func parseAdddress(address string) (string, string, error) { + if !strings.Contains(address, "//") { + return "tcp4", address, nil + } + + u, err := url.Parse(address) + if err != nil { + return "", "", err + } + + if strings.ToUpper(u.Scheme) == "UNIX" { + return u.Scheme, u.Path, nil + } + + return u.Scheme, u.Host, nil +} diff --git a/application_test.go b/application_test.go index e1929fa..ae73ba2 100644 --- a/application_test.go +++ b/application_test.go @@ -11,96 +11,134 @@ import ( ) var ( - testAddr = "127.0.0.1:12345" + testAddr = "127.0.0.1:12345" + testUDPAddr = "udp://127.0.0.1:12346" + testUnixSocket = "unix:///tmp/goetty.sock" + + testAddresses = map[string]string{ + "tcp": testAddr, + "unix": testUnixSocket, + } ) func TestStart(t *testing.T) { defer leaktest.AfterTest(t)() - app := newTestTCPApp(t, nil) - defer app.Stop() + for name, address := range testAddresses { + addr := address + t.Run(name, func(t *testing.T) { + app := newTestApp(t, addr, nil) + defer app.Stop() - assert.NoError(t, app.Start()) - assert.NoError(t, app.Start()) + assert.NoError(t, app.Start()) + assert.NoError(t, app.Start()) + }) + } } func TestStop(t *testing.T) { - defer leaktest.AfterTest(t) - - app := newTestTCPApp(t, nil).(*server) - assert.NoError(t, app.Start()) - - n := 10 - for i := 0; i < n; i++ { - session := newTestIOSession(t) - ok, err := session.Connect(testAddr, time.Second) - assert.NoError(t, err) - assert.True(t, ok) - assert.NoError(t, session.WriteAndFlush("test")) - } - - assert.NoError(t, app.Stop()) + defer leaktest.AfterTest(t)() - c := 0 - for _, m := range app.sessions { - m.Lock() - c += len(m.sessions) - m.Unlock() + for name, address := range testAddresses { + addr := address + t.Run(name, func(t *testing.T) { + app := newTestApp(t, addr, nil).(*server) + assert.NoError(t, app.Start()) + + n := 10 + for i := 0; i < n; i++ { + session := newTestIOSession(t) + ok, err := session.Connect(addr, time.Second) + assert.NoError(t, err) + assert.True(t, ok) + assert.NoError(t, session.WriteAndFlush("test")) + } + + assert.NoError(t, app.Stop()) + + c := 0 + for _, m := range app.sessions { + m.Lock() + c += len(m.sessions) + m.Unlock() + } + + assert.Equal(t, 0, c) + }) } - assert.Equal(t, 0, c) } func TestCloseBlock(t *testing.T) { - defer leaktest.AfterTest(t) + defer leaktest.AfterTest(t)() - app := newTestTCPApp(t, nil).(*server) - assert.NoError(t, app.Start()) + for name, address := range testAddresses { + addr := address + t.Run(name, func(t *testing.T) { + app := newTestApp(t, addr, nil).(*server) + assert.NoError(t, app.Start()) + + conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample())) + ok, err := conn.Connect(addr, time.Second) + assert.NoError(t, err) + assert.True(t, ok) + assert.NoError(t, app.Stop()) + assert.NoError(t, conn.Write(string(make([]byte, 1024*1024)))) + assert.NoError(t, conn.Close()) + }) + } - conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample())) - ok, err := conn.Connect(testAddr, time.Second) - assert.NoError(t, err) - assert.True(t, ok) - assert.NoError(t, app.Stop()) - assert.NoError(t, conn.Write(string(make([]byte, 1024*1024)))) - assert.NoError(t, conn.Close()) } func TestIssue13(t *testing.T) { - defer leaktest.AfterTest(t) - - app := newTestTCPApp(t, nil).(*server) - assert.NoError(t, app.Start()) + defer leaktest.AfterTest(t)() - conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample())) - ok, err := conn.Connect(testAddr, time.Second) - assert.NoError(t, err) - assert.True(t, ok) - - errC := make(chan error) - go func() { - _, err := conn.Read() - if err != nil { - errC <- err - return - } - }() + for name, address := range testAddresses { + addr := address + t.Run(name, func(t *testing.T) { + app := newTestApp(t, addr, nil).(*server) + assert.NoError(t, app.Start()) + + conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample())) + ok, err := conn.Connect(addr, time.Second) + assert.NoError(t, err) + assert.True(t, ok) + + defer conn.Close() + + errC := make(chan error) + go func() { + _, err := conn.Read() + if err != nil { + errC <- err + return + } + }() + + time.Sleep(time.Millisecond * 100) + assert.NoError(t, app.Stop()) + + select { + case <-errC: + return + case <-time.After(time.Second * 1): + assert.Fail(t, "timeout") + } + }) + } - time.Sleep(time.Millisecond * 100) - assert.NoError(t, app.Stop()) +} - select { - case <-errC: - return - case <-time.After(time.Second * 1): - assert.Fail(t, "timeout") +func newTestApp(t *testing.T, address string, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) NetApplication { + if handleFunc == nil { + handleFunc = func(i1 IOSession, i2 interface{}, u uint64) error { + return nil + } } -} -func newTestTCPApp(t *testing.T, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) NetApplication { encoder, decoder := simple.NewStringCodec() opts = append(opts, WithAppSessionOptions(WithCodec(encoder, decoder))) - app, err := NewTCPApplication(testAddr, handleFunc, opts...) + app, err := NewApplication(address, handleFunc, opts...) assert.NoError(t, err) return app @@ -111,3 +149,20 @@ func newTestIOSession(t *testing.T, opts ...Option) IOSession { opts = append(opts, WithCodec(encoder, decoder)) return NewIOSession(opts...) } + +func TestParseAddress(t *testing.T) { + network, address, err := parseAdddress(testAddr) + assert.NoError(t, err) + assert.Equal(t, "tcp4", network) + assert.Equal(t, testAddr, address) + + network, address, err = parseAdddress(testUDPAddr) + assert.NoError(t, err) + assert.Equal(t, "udp", network) + assert.Equal(t, "127.0.0.1:12346", address) + + network, address, err = parseAdddress(testUnixSocket) + assert.NoError(t, err) + assert.Equal(t, "unix", network) + assert.Equal(t, "/tmp/goetty.sock", address) +} diff --git a/example/echo_server.go b/example/echo_server.go index ae16c14..0aedc2f 100644 --- a/example/echo_server.go +++ b/example/echo_server.go @@ -17,7 +17,7 @@ type EchoServer struct { func NewEchoServer(addr string) *EchoServer { svr := &EchoServer{} encoder, decoder := simple.NewStringCodec() - app, err := goetty.NewTCPApplication(addr, svr.handle, + app, err := goetty.NewApplication(addr, svr.handle, goetty.WithAppSessionOptions(goetty.WithCodec(encoder, decoder))) if err != nil { log.Panicf("start server failed with %+v", err) diff --git a/session.go b/session.go index 6c67489..cae84ea 100644 --- a/session.go +++ b/session.go @@ -111,7 +111,12 @@ func (bio *baseIO) ID() uint64 { return bio.id } -func (bio *baseIO) Connect(addr string, timeout time.Duration) (bool, error) { +func (bio *baseIO) Connect(addressWithNetwork string, timeout time.Duration) (bool, error) { + network, address, err := parseAdddress(addressWithNetwork) + if err != nil { + return false, err + } + if bio.disableConnect { return false, ErrDisableConnect } @@ -138,7 +143,7 @@ func (bio *baseIO) Connect(addr string, timeout time.Duration) (bool, error) { return false, fmt.Errorf("the session is closing or connecting is other goroutine") } - conn, err := net.DialTimeout("tcp", addr, timeout) + conn, err := net.DialTimeout(network, address, timeout) if nil != err { atomic.StoreInt32(&bio.state, stateReadyToConnect) return false, err diff --git a/session_test.go b/session_test.go index cbcee72..25d8773 100644 --- a/session_test.go +++ b/session_test.go @@ -10,83 +10,95 @@ import ( ) func TestNormal(t *testing.T) { - defer leaktest.AfterTest(t) - - var cs IOSession - cnt := uint64(0) - app := newTestTCPApp(t, func(rs IOSession, msg interface{}, received uint64) error { - cs = rs - rs.WriteAndFlush(msg) - atomic.StoreUint64(&cnt, received) - return nil - }) - app.Start() - defer app.Stop() - - client := newTestIOSession(t, WithTimeout(time.Second, time.Second)) - ok, err := client.Connect(testAddr, time.Second) - assert.NoError(t, err) - assert.True(t, ok) - assert.True(t, client.Connected()) - - assert.NoError(t, client.WriteAndFlush("hello")) - reply, err := client.Read() - assert.NoError(t, err) - assert.Equal(t, "hello", reply) - assert.Equal(t, uint64(1), atomic.LoadUint64(&cnt)) - - v, err := app.GetSession(cs.ID()) - assert.NoError(t, err) - assert.NotNil(t, v) - - assert.NoError(t, app.Broadcast("world")) - reply, err = client.Read() - assert.NoError(t, err) - assert.Equal(t, "world", reply) - - assert.NoError(t, client.Close()) - assert.False(t, client.Connected()) - assert.Error(t, client.WriteAndFlush("hello")) - - time.Sleep(time.Millisecond * 100) - v, err = app.GetSession(cs.ID()) - assert.NoError(t, err) - assert.Nil(t, v) - - ok, err = client.Connect(testAddr, time.Second) - assert.NoError(t, err) - assert.True(t, ok) - assert.True(t, client.Connected()) + defer leaktest.AfterTest(t)() + + for name, address := range testAddresses { + addr := address + t.Run(name, func(t *testing.T) { + var cs IOSession + cnt := uint64(0) + app := newTestApp(t, addr, func(rs IOSession, msg interface{}, received uint64) error { + cs = rs + rs.WriteAndFlush(msg) + atomic.StoreUint64(&cnt, received) + return nil + }) + app.Start() + defer app.Stop() + + client := newTestIOSession(t, WithTimeout(time.Second, time.Second)) + ok, err := client.Connect(addr, time.Second) + assert.NoError(t, err) + assert.True(t, ok) + assert.True(t, client.Connected()) + + assert.NoError(t, client.WriteAndFlush("hello")) + reply, err := client.Read() + assert.NoError(t, err) + assert.Equal(t, "hello", reply) + assert.Equal(t, uint64(1), atomic.LoadUint64(&cnt)) + + v, err := app.GetSession(cs.ID()) + assert.NoError(t, err) + assert.NotNil(t, v) + + assert.NoError(t, app.Broadcast("world")) + reply, err = client.Read() + assert.NoError(t, err) + assert.Equal(t, "world", reply) + + assert.NoError(t, client.Close()) + assert.False(t, client.Connected()) + assert.Error(t, client.WriteAndFlush("hello")) + + time.Sleep(time.Millisecond * 100) + v, err = app.GetSession(cs.ID()) + assert.NoError(t, err) + assert.Nil(t, v) + + ok, err = client.Connect(addr, time.Second) + assert.NoError(t, err) + assert.True(t, ok) + assert.True(t, client.Connected()) + }) + } } func TestAsyncWrite(t *testing.T) { - defer leaktest.AfterTest(t) - - app := newTestTCPApp(t, func(rs IOSession, msg interface{}, received uint64) error { - rs.WriteAndFlush(msg) - return nil - }) - app.Start() - defer app.Stop() - - client := newTestIOSession(t, WithTimeout(time.Second, time.Second), WithEnableAsyncWrite(16)) - ok, err := client.Connect(testAddr, time.Second) - assert.NoError(t, err) - assert.True(t, ok) - assert.True(t, client.Connected()) - - assert.NoError(t, client.WriteAndFlush("hello")) - reply, err := client.Read() - assert.NoError(t, err) - assert.Equal(t, "hello", reply) - - assert.NoError(t, client.Close()) - ok, err = client.Connect(testAddr, time.Second) - assert.NoError(t, err) - assert.True(t, ok) - assert.True(t, client.Connected()) - assert.NoError(t, client.WriteAndFlush("hello")) - reply, err = client.Read() - assert.NoError(t, err) - assert.Equal(t, "hello", reply) + defer leaktest.AfterTest(t)() + + for name, address := range testAddresses { + addr := address + t.Run(name, func(t *testing.T) { + app := newTestApp(t, addr, func(rs IOSession, msg interface{}, received uint64) error { + rs.WriteAndFlush(msg) + return nil + }) + app.Start() + defer app.Stop() + + client := newTestIOSession(t, WithTimeout(time.Second, time.Second), WithEnableAsyncWrite(16)) + defer client.Close() + + ok, err := client.Connect(addr, time.Second) + assert.NoError(t, err) + assert.True(t, ok) + assert.True(t, client.Connected()) + + assert.NoError(t, client.WriteAndFlush("hello")) + reply, err := client.Read() + assert.NoError(t, err) + assert.Equal(t, "hello", reply) + + assert.NoError(t, client.Close()) + ok, err = client.Connect(addr, time.Second) + assert.NoError(t, err) + assert.True(t, ok) + assert.True(t, client.Connected()) + assert.NoError(t, client.WriteAndFlush("hello")) + reply, err = client.Read() + assert.NoError(t, err) + assert.Equal(t, "hello", reply) + }) + } }