Skip to content

Commit

Permalink
Support unix socket (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxu19830126 authored Jan 18, 2022
1 parent efdfc7b commit 9611a16
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 160 deletions.
49 changes: 31 additions & 18 deletions application.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import (
"errors"
"io"
"net"
"runtime"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -44,8 +45,8 @@ type server struct {
}
}

// NewApplication returns a net application with listener
func NewApplication(listener net.Listener, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) {
// NewApplicationWithListener returns a net application with listener
func NewApplicationWithListener(listener net.Listener, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) {
s := &server{
listener: listener,
handleFunc: handleFunc,
Expand All @@ -70,17 +71,23 @@ func NewApplication(listener net.Listener, handleFunc func(IOSession, interface{
return s, nil
}

// NewTCPApplication returns a net application
func NewTCPApplication(addr string, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) {
// NewApplication returns a application
func NewApplication(address string, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) (NetApplication, error) {
network, address, err := parseAdddress(address)
if err != nil {
return nil, err
}

listenConfig := &net.ListenConfig{
Control: listenControl,
}
listener, err := listenConfig.Listen(context.TODO(), "tcp4", addr)

listener, err := listenConfig.Listen(context.TODO(), network, address)
if err != nil {
return nil, err
}

return NewApplication(listener, handleFunc, opts...)
return NewApplicationWithListener(listener, handleFunc, opts...)
}

func (s *server) Start() error {
Expand Down Expand Up @@ -183,17 +190,6 @@ func (s *server) doStart() {
s.addSession(rs)

go func() {
defer func() {
if err := recover(); err != nil {
const size = 64 << 10
rBuf := make([]byte, size)
rBuf = rBuf[:runtime.Stack(rBuf, false)]
s.opts.logger.Error("connection painc",
zap.Any("err", err),
zap.String("stack", string(rBuf)))
}
}()

defer func() {
s.deleteSession(rs)
rs.Close()
Expand Down Expand Up @@ -268,3 +264,20 @@ func (s *server) isStarted() bool {

return s.mu.running
}

func parseAdddress(address string) (string, string, error) {
if !strings.Contains(address, "//") {
return "tcp4", address, nil
}

u, err := url.Parse(address)
if err != nil {
return "", "", err
}

if strings.ToUpper(u.Scheme) == "UNIX" {
return u.Scheme, u.Path, nil
}

return u.Scheme, u.Host, nil
}
181 changes: 118 additions & 63 deletions application_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,96 +11,134 @@ import (
)

var (
testAddr = "127.0.0.1:12345"
testAddr = "127.0.0.1:12345"
testUDPAddr = "udp://127.0.0.1:12346"
testUnixSocket = "unix:///tmp/goetty.sock"

testAddresses = map[string]string{
"tcp": testAddr,
"unix": testUnixSocket,
}
)

func TestStart(t *testing.T) {
defer leaktest.AfterTest(t)()

app := newTestTCPApp(t, nil)
defer app.Stop()
for name, address := range testAddresses {
addr := address
t.Run(name, func(t *testing.T) {
app := newTestApp(t, addr, nil)
defer app.Stop()

assert.NoError(t, app.Start())
assert.NoError(t, app.Start())
assert.NoError(t, app.Start())
assert.NoError(t, app.Start())
})
}
}

func TestStop(t *testing.T) {
defer leaktest.AfterTest(t)

app := newTestTCPApp(t, nil).(*server)
assert.NoError(t, app.Start())

n := 10
for i := 0; i < n; i++ {
session := newTestIOSession(t)
ok, err := session.Connect(testAddr, time.Second)
assert.NoError(t, err)
assert.True(t, ok)
assert.NoError(t, session.WriteAndFlush("test"))
}

assert.NoError(t, app.Stop())
defer leaktest.AfterTest(t)()

c := 0
for _, m := range app.sessions {
m.Lock()
c += len(m.sessions)
m.Unlock()
for name, address := range testAddresses {
addr := address
t.Run(name, func(t *testing.T) {
app := newTestApp(t, addr, nil).(*server)
assert.NoError(t, app.Start())

n := 10
for i := 0; i < n; i++ {
session := newTestIOSession(t)
ok, err := session.Connect(addr, time.Second)
assert.NoError(t, err)
assert.True(t, ok)
assert.NoError(t, session.WriteAndFlush("test"))
}

assert.NoError(t, app.Stop())

c := 0
for _, m := range app.sessions {
m.Lock()
c += len(m.sessions)
m.Unlock()
}

assert.Equal(t, 0, c)
})
}

assert.Equal(t, 0, c)
}

func TestCloseBlock(t *testing.T) {
defer leaktest.AfterTest(t)
defer leaktest.AfterTest(t)()

app := newTestTCPApp(t, nil).(*server)
assert.NoError(t, app.Start())
for name, address := range testAddresses {
addr := address
t.Run(name, func(t *testing.T) {
app := newTestApp(t, addr, nil).(*server)
assert.NoError(t, app.Start())

conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample()))
ok, err := conn.Connect(addr, time.Second)
assert.NoError(t, err)
assert.True(t, ok)
assert.NoError(t, app.Stop())
assert.NoError(t, conn.Write(string(make([]byte, 1024*1024))))
assert.NoError(t, conn.Close())
})
}

conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample()))
ok, err := conn.Connect(testAddr, time.Second)
assert.NoError(t, err)
assert.True(t, ok)
assert.NoError(t, app.Stop())
assert.NoError(t, conn.Write(string(make([]byte, 1024*1024))))
assert.NoError(t, conn.Close())
}

func TestIssue13(t *testing.T) {
defer leaktest.AfterTest(t)

app := newTestTCPApp(t, nil).(*server)
assert.NoError(t, app.Start())
defer leaktest.AfterTest(t)()

conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample()))
ok, err := conn.Connect(testAddr, time.Second)
assert.NoError(t, err)
assert.True(t, ok)

errC := make(chan error)
go func() {
_, err := conn.Read()
if err != nil {
errC <- err
return
}
}()
for name, address := range testAddresses {
addr := address
t.Run(name, func(t *testing.T) {
app := newTestApp(t, addr, nil).(*server)
assert.NoError(t, app.Start())

conn := newTestIOSession(t, WithEnableAsyncWrite(16), WithLogger(zap.NewExample()))
ok, err := conn.Connect(addr, time.Second)
assert.NoError(t, err)
assert.True(t, ok)

defer conn.Close()

errC := make(chan error)
go func() {
_, err := conn.Read()
if err != nil {
errC <- err
return
}
}()

time.Sleep(time.Millisecond * 100)
assert.NoError(t, app.Stop())

select {
case <-errC:
return
case <-time.After(time.Second * 1):
assert.Fail(t, "timeout")
}
})
}

time.Sleep(time.Millisecond * 100)
assert.NoError(t, app.Stop())
}

select {
case <-errC:
return
case <-time.After(time.Second * 1):
assert.Fail(t, "timeout")
func newTestApp(t *testing.T, address string, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) NetApplication {
if handleFunc == nil {
handleFunc = func(i1 IOSession, i2 interface{}, u uint64) error {
return nil
}
}
}

func newTestTCPApp(t *testing.T, handleFunc func(IOSession, interface{}, uint64) error, opts ...AppOption) NetApplication {
encoder, decoder := simple.NewStringCodec()
opts = append(opts, WithAppSessionOptions(WithCodec(encoder, decoder)))
app, err := NewTCPApplication(testAddr, handleFunc, opts...)
app, err := NewApplication(address, handleFunc, opts...)
assert.NoError(t, err)

return app
Expand All @@ -111,3 +149,20 @@ func newTestIOSession(t *testing.T, opts ...Option) IOSession {
opts = append(opts, WithCodec(encoder, decoder))
return NewIOSession(opts...)
}

func TestParseAddress(t *testing.T) {
network, address, err := parseAdddress(testAddr)
assert.NoError(t, err)
assert.Equal(t, "tcp4", network)
assert.Equal(t, testAddr, address)

network, address, err = parseAdddress(testUDPAddr)
assert.NoError(t, err)
assert.Equal(t, "udp", network)
assert.Equal(t, "127.0.0.1:12346", address)

network, address, err = parseAdddress(testUnixSocket)
assert.NoError(t, err)
assert.Equal(t, "unix", network)
assert.Equal(t, "/tmp/goetty.sock", address)
}
2 changes: 1 addition & 1 deletion example/echo_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type EchoServer struct {
func NewEchoServer(addr string) *EchoServer {
svr := &EchoServer{}
encoder, decoder := simple.NewStringCodec()
app, err := goetty.NewTCPApplication(addr, svr.handle,
app, err := goetty.NewApplication(addr, svr.handle,
goetty.WithAppSessionOptions(goetty.WithCodec(encoder, decoder)))
if err != nil {
log.Panicf("start server failed with %+v", err)
Expand Down
9 changes: 7 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ func (bio *baseIO) ID() uint64 {
return bio.id
}

func (bio *baseIO) Connect(addr string, timeout time.Duration) (bool, error) {
func (bio *baseIO) Connect(addressWithNetwork string, timeout time.Duration) (bool, error) {
network, address, err := parseAdddress(addressWithNetwork)
if err != nil {
return false, err
}

if bio.disableConnect {
return false, ErrDisableConnect
}
Expand All @@ -138,7 +143,7 @@ func (bio *baseIO) Connect(addr string, timeout time.Duration) (bool, error) {
return false, fmt.Errorf("the session is closing or connecting is other goroutine")
}

conn, err := net.DialTimeout("tcp", addr, timeout)
conn, err := net.DialTimeout(network, address, timeout)
if nil != err {
atomic.StoreInt32(&bio.state, stateReadyToConnect)
return false, err
Expand Down
Loading

0 comments on commit 9611a16

Please sign in to comment.