// Implementations for receiving incoming email messages and placing them
// on a sendable channel for batching/summarizing/processing.
package main
import (
"bufio"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"os"
"sync"
"syscall"
"time"
)
// Listener binds a socket on an address, and accepts email messages via SMTP
// on each incoming connection.
type Listener struct {
Socket ServerSocket
Auth Auth
Security SessionSecurity
TLSConfig *tls.Config
Debug bool
Rewriter AddressRewriter
conns int
}
// ServerSocket is a `net.Listener` that can return its file descriptor.
type ServerSocket interface {
net.Listener
Fd() (uintptr, error)
String() string
}
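// SSLServerSocket wraps another ServerSocket so that accepted connections
// speak TLS from the first byte, while Fd() and String() still describe the
// underlying listening socket.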
type SSLServerSocket struct {
net.Listener
orig ServerSocket
}
func NewSSLServerSocket(socket ServerSocket, config *tls.Config) *SSLServerSocket {
ssl := tls.NewListener(socket, config)
return &SSLServerSocket{ssl, socket}
}
func (s *SSLServerSocket) Fd() (uintptr, error) {
return s.orig.Fd()
}
func (s *SSLServerSocket) String() string {
return "ssl:" + s.orig.String()
}
// TCPServerSocket is a ServerSocket implementation for listeners that bind and
// listen on a TCP address.
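//
// A construction sketch (the address and `tlsConfig` here are illustrative,
// not taken from this file):
//
//	tcpSock, err := NewTCPServerSocket("0.0.0.0:2525")
//	if err != nil {
//		log.Fatalf("bind: %s", err)
//	}
//	var socket ServerSocket = tcpSock
//	// Optionally wrap it so every accepted connection is TLS from the start:
//	socket = NewSSLServerSocket(tcpSock, tlsConfig)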
type TCPServerSocket struct {
*net.TCPListener
addr string
}
func NewTCPServerSocket(addr string) (*TCPServerSocket, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenTCP("tcp", tcpAddr)
return &TCPServerSocket{ln, addr}, err
}
func (t *TCPServerSocket) Fd() (uintptr, error) {
if file, err := t.File(); err != nil {
return 0, err
} else {
return file.Fd(), nil
}
}
func (t *TCPServerSocket) String() string {
return t.addr
}
// FileServerSocket is a ServerSocket implementation for listeners that open an
// existing socket by its file descriptor.
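//
// For example (hypothetical): a child process that inherited the listening
// socket as file descriptor 3 through exec.Cmd's ExtraFiles could reopen it
// with:
//
//	sock, err := NewFileServerSocket(3)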
type FileServerSocket struct {
net.Listener
}
func NewFileServerSocket(fd uintptr) (*FileServerSocket, error) {
file := os.NewFile(fd, "socket")
ln, err := net.FileListener(file)
// We used to syscall.Close(int(fd)) here because FileListener dups it, and
// we don't need the original anymore after that.
//
// It turns out that doing that can put the underlying socket in a bad
// state in some tricky way that I don't fully understand. The downside to not
// calling it is that we have two open FDs pointing to the same socket
// during the lifetime of the program, but only one (the correct one) ends
// up being inherited on reload, so this is actually fine.
return &FileServerSocket{ln}, err
}
func (f *FileServerSocket) Fd() (uintptr, error) {
if tcpListener, ok := f.Listener.(*net.TCPListener); !ok {
return 0, fmt.Errorf("%s is not a TCP socket", f)
} else if file, err := tcpListener.File(); err != nil {
return 0, err
} else {
return file.Fd(), nil
}
}
func (f *FileServerSocket) String() string {
return "fd from file"
}
// WaitWithTimeout calls `Wait()` on a `sync.WaitGroup`, blocking for no more
// than the timeout. It returns true if `Wait()` returned before the timeout
// expired, and false otherwise.
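//
// A usage sketch (`doWork` is a hypothetical stand-in for real work):
//
//	var wg sync.WaitGroup
//	wg.Add(1)
//	go func() {
//		defer wg.Done()
//		doWork()
//	}()
//	if !WaitWithTimeout(&wg, 5*time.Second) {
//		log.Printf("timed out waiting for work to finish")
//	}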
func WaitWithTimeout(waitGroup *sync.WaitGroup, timeout time.Duration) bool {
done := make(chan interface{}, 1) // buffered so the goroutine below can exit even after a timeout
go func() {
waitGroup.Wait()
done <- nil
}()
timer := time.After(timeout)
for {
select {
case <-timer:
return false
case <-done:
return true
}
}
}
// Listen accepts connections on the Listener's socket, putting all messages
// received via SMTP onto the `received` channel.
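//
// On a Reload request, the file descriptor it returns is a duplicate of the
// listening socket, marked close-on-exec so the child only inherits the copy
// passed via ExtraFiles. A caller might hand it to a re-exec'd child roughly
// like this (a sketch, not code from this file):
//
//	cmd := exec.Command(os.Args[0], os.Args[1:]...)
//	cmd.ExtraFiles = []*os.File{os.NewFile(fd, "listener")}
//	// ...start cmd; the child reopens the socket with NewFileServerSocket(3).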
func (l *Listener) Listen(received chan<- *StorageRequest, done <-chan TerminationRequest, shutdownTimeout time.Duration) (uintptr, error) {
log.Printf("listening: %s", l.Socket)
waitGroup := new(sync.WaitGroup)
acceptFinished := make(chan bool, 0)
// Accept connections in a goroutine, and add them to the WaitGroup.
go func() {
for {
conn, err := l.Socket.Accept()
if err != nil {
log.Printf("error accepting connection: %s", err)
break
}
l.conns++
// Handle each incoming connection in its own goroutine.
log.Printf("handling new connection from %s", conn.RemoteAddr())
waitGroup.Add(1)
go func() {
defer waitGroup.Done()
l.handleConnection(conn, received)
log.Printf("done handling new connection from %s", conn.RemoteAddr())
}()
}
// When we've broken out of the loop for any reason (errors, limit),
// signal that we're done via the channel.
acceptFinished <- true
}()
newFd := 0
// Wait for either a shutdown/reload request, or for the Accept() loop to
// break on its own (from error or a limit).
select {
case req := <-done:
// If we got a reload request, set up a file descriptor to pass to the
// reloaded process.
if req == Reload {
fd, err := l.Socket.Fd()
if err != nil {
return 0, err
}
// If we don't dup the fd, closing it below (to break the Accept()
// loop) will prevent us from being able to use it as a socket in
// the child process.
newFd, err = syscall.Dup(int(fd))
if err != nil {
return 0, err
}
// If we don't mark the new fd as CLOEXEC, the child process will
// inherit it twice (the second one being the one passed to
// ExtraFiles).
syscall.CloseOnExec(newFd)
}
log.Printf("closing listening socket")
if err := l.Socket.Close(); err != nil {
return 0, err
}
// Wait for the Close() to break us out of the Accept() loop.
<-acceptFinished
case <-acceptFinished:
// If the accept loop is done on its own (e.g. not from a reload
// request), fall through to do some cleanup.
}
// Wait for any open sessions to finish, or time out.
log.Printf("waiting %s for open connections to finish", shutdownTimeout)
WaitWithTimeout(waitGroup, shutdownTimeout)
close(received)
return uintptr(newFd), nil
}
// handleConnection reads SMTP commands from a socket and writes back SMTP
// responses. Since it takes several commands (MAIL, RCPT, DATA) to fully
// describe a message, `Session` is used to keep track of the progress building
// a message. When a message has been fully communicated by a downstream
// client, it's put on the `received` channel for later batching/summarizing.
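//
// A typical exchange with a downstream client looks roughly like this
// (illustrative only; the exact responses come from `Session`):
//
//	S: 220 <banner>
//	C: HELO client.example
//	S: 250 ok
//	C: MAIL FROM:<sender@example.com>
//	S: 250 ok
//	C: RCPT TO:<recipient@example.com>
//	S: 250 ok
//	C: DATA
//	S: 354 end data with <CR><LF>.<CR><LF>
//	C: (message headers and body, terminated by a line containing only ".")
//	S: 250 ok
//	C: QUIT
//	S: 221 bye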
func (l *Listener) handleConnection(conn io.ReadWriteCloser, received chan<- *StorageRequest) {
defer conn.Close()
origReader := bufio.NewReader(conn)
origWriter := bufio.NewWriter(conn)
// In debug mode, wrap the readers and writers.
var reader stringReader
var writer stringWriter
if l.Debug {
prefix := fmt.Sprintf("%v ", conn)
reader = &debugReader{origReader, prefix}
writer = &debugWriter{origWriter, prefix}
} else {
reader = origReader
writer = origWriter
}
session := new(Session)
if err := session.Start(l.Auth, l.Security).WriteTo(writer); err != nil {
log.Printf("error writing to client: %s", err)
return
}
for {
resp, err := session.ReadCommand(reader)
if err != nil {
log.Printf("error reading from client: %s", err)
break
}
if err := resp.WriteTo(writer); err != nil {
log.Printf("error writing to client after reading command: %s", err)
break
}
switch {
case resp.IsClose():
return
case resp.NeedsData():
resp, msg := session.ReadData(reader)
if msg != nil {
log.Printf("received message with subject %#v", msg.Parsed.Header.Get("Subject"))
msg.RedirectedTo = l.Rewriter.RewriteAll(msg.To)
errors := make(chan error, 0)
received <- &StorageRequest{msg, errors}
if err := <-errors; err != nil {
errorResp := Response{451, err.Error()}
if err := errorResp.WriteTo(writer); err != nil {
log.Printf("error writing to client after storage failure: %s", err)
return
}
} else {
if err := resp.WriteTo(writer); err != nil {
log.Printf("error writing to client after reading data: %s", err)
return
}
}
}
case resp.NeedsAuthResponse():
resp := session.ReadAuthResponse(reader)
if err := resp.WriteTo(writer); err != nil {
log.Printf("error writing to client after reading auth: %s", err)
return
}
case resp.StartsTLS():
netConn, ok := conn.(net.Conn)
if !ok {
log.Printf("error getting underlying connection for STARTTLS")
return
}
tlsConn := tls.Server(netConn, l.TLSConfig)
origReader.Reset(tlsConn)
origWriter.Reset(tlsConn)
session.security = TLS_POST_STARTTLS
defer tlsConn.Close()
}
}
}