-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathproxy.go
125 lines (110 loc) · 2.48 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package goetty
import (
"errors"
"io"
"sync"
"time"
"go.uber.org/zap"
)
// Proxy is a simple reverse proxy: it accepts client connections and relays
// each one to an upstream registered with AddUpStream. The implementation
// returned by NewProxy rotates through upstreams in round-robin order.
type Proxy interface {
	// AddUpStream registers an upstream reachable at address. connectTimeout
	// is used as the timeout when the proxy connects to that upstream.
	AddUpStream(address string, connectTimeout time.Duration)
	// Start begins accepting client connections.
	Start() error
	// Stop shuts the proxy down.
	Stop() error
}
// NewProxy returns a simple tcp Proxy that will listen on address once
// Start is called. The supplied logger is passed through adjustLogger before
// being stored (presumably to substitute a default for nil — confirm).
func NewProxy[IN any, OUT any](address string, logger *zap.Logger) Proxy {
	p := &proxy[IN, OUT]{address: address}
	p.logger = adjustLogger(logger)
	return p
}
// proxy is the Proxy implementation built by NewProxy. Every client session
// accepted on address is relayed to an upstream picked round-robin from
// mu.upstreamList.
type proxy[IN any, OUT any] struct {
	logger  *zap.Logger
	address string
	// server is assigned by Start once the application has been created;
	// it is nil before that.
	server NetApplication[IN, OUT]

	// mu guards the fields nested inside it.
	mu struct {
		sync.Mutex
		// seq is the round-robin cursor advanced by getUpStream.
		seq uint64
		// upstreamList holds every upstream registered via AddUpStream.
		upstreamList []*upstream
	}
}
// Start creates the listening application on p.address, routes every accepted
// session to handleSession, and starts it. The application is stored in
// p.server before being started, so Stop can reach it.
func (p *proxy[IN, OUT]) Start() error {
	app, err := NewApplication(
		p.address,
		nil, // NOTE(review): no per-message handler passed; sessions go to handleSession — confirm parameter meaning
		WithAppHandleSessionFunc(p.handleSession),
	)
	if err != nil {
		return err
	}

	p.server = app
	return app.Start()
}
// Stop stops the application server created by Start.
//
// If Start never created a server (it was not called, or NewApplication
// failed), there is nothing to stop and Stop returns nil instead of invoking
// a method on a nil interface.
func (p *proxy[IN, OUT]) Stop() error {
	if p.server == nil {
		return nil
	}
	return p.server.Stop()
}
// AddUpStream appends a new upstream (address plus the timeout to use when
// connecting to it) to the proxy's rotation. The list is mutated under p.mu.
func (p *proxy[IN, OUT]) AddUpStream(address string, connectTimeout time.Duration) {
	entry := &upstream{address: address, connectTimeout: connectTimeout}

	p.mu.Lock()
	p.mu.upstreamList = append(p.mu.upstreamList, entry)
	p.mu.Unlock()
}
// getUpStream picks the next upstream in round-robin order, or returns nil
// when none are registered. The cursor only advances when an upstream is
// actually returned.
func (p *proxy[IN, OUT]) getUpStream() *upstream {
	p.mu.Lock()
	defer p.mu.Unlock()

	count := len(p.mu.upstreamList)
	if count == 0 {
		return nil
	}

	idx := p.mu.seq % uint64(count)
	p.mu.seq++
	return p.mu.upstreamList[idx]
}
// handleSession proxies one accepted client session to an upstream.
//
// An upstream is chosen with getUpStream and connected to with its configured
// connect timeout. Bytes are then copied in both directions between the
// client's raw connection and the upstream's raw connection until both
// directions have finished; the upstream session is closed on return.
//
// When one direction stops (EOF or error), the write side of the opposite
// connection is shut down so that peer observes EOF. A peer that closes on
// EOF then lets the remaining copy finish as well. Without this, a client
// that disconnects while the upstream stays open (or vice versa) would leave
// this function blocked in wg.Wait indefinitely.
//
// It returns an error when no upstream is registered, when connecting to the
// upstream fails, or when copying from the client to the upstream fails.
func (p *proxy[IN, OUT]) handleSession(conn IOSession[IN, OUT]) error {
	// Named "up" rather than "upstream" to avoid shadowing the upstream type.
	up := p.getUpStream()
	if up == nil {
		return errors.New("no upstream")
	}

	upstreamConn := NewIOSession[IN, OUT]()
	if err := upstreamConn.Connect(up.address, up.connectTimeout); err != nil {
		return err
	}
	defer func() {
		if err := upstreamConn.Close(); err != nil {
			p.logger.Error("close upstream failed",
				zap.String("upstream", up.address),
				zap.Error(err))
		}
	}()

	// shutdownWrite signals EOF to the remote side of c. Connections that
	// support half-close (such as *net.TCPConn) only shut down their write
	// side, so data still flowing the other way is not cut off; any other
	// connection is closed outright.
	shutdownWrite := func(c io.Closer) error {
		if hc, ok := c.(interface{ CloseWrite() error }); ok {
			return hc.CloseWrite()
		}
		return c.Close()
	}

	srcConn := conn.RawConn()
	dstConn := upstreamConn.RawConn()

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		if _, err := io.Copy(srcConn, dstConn); err != nil {
			p.logger.Error("copy data from upstream to client failed",
				zap.String("upstream", up.address),
				zap.Error(err))
		}
		// Upstream finished sending: let the client see EOF. Best effort —
		// the client may already be gone, in which case this fails harmlessly.
		_ = shutdownWrite(srcConn)
	}()

	_, err := io.Copy(dstConn, srcConn)
	if err != nil {
		p.logger.Error("copy data from client to upstream failed",
			zap.String("upstream", up.address),
			zap.Error(err))
	}
	// Client finished sending: let the upstream see EOF so the copy in the
	// goroutine can end. Best effort for the same reason as above.
	_ = shutdownWrite(dstConn)

	wg.Wait()
	return err
}
// upstream describes one backend registered through AddUpStream.
type upstream struct {
	// address is what handleSession passes to IOSession.Connect.
	address string
	// connectTimeout is the timeout handed to Connect alongside address.
	connectTimeout time.Duration
}