diff --git a/server.go b/server.go index 98702bf5..4c8739de 100644 --- a/server.go +++ b/server.go @@ -377,6 +377,7 @@ func (server *server) handleClient(client *client) { advertiseTLS = "" } + firstCmd := true for client.isAlive() { switch client.state { case ClientGreeting: @@ -413,6 +414,30 @@ func (server *server) handleClient(client *client) { } cmd := strings.ToUpper(input[:cmdLen]) switch { + case strings.Index(cmd, "PROXY ") == 0: + if firstCmd == false { + client.sendResponse(fmt.Sprintf("%d%s %s", response.ClassPermanentFailure, response.InvalidCommand, "PROXY must be the first command")) + break + } + proxyTokens := strings.Split(cmd, " ") + if len(proxyTokens) <= 1 { + client.sendResponse(fmt.Sprintf("%d%s %s", response.ClassPermanentFailure, response.InvalidCommandArguments, "Invalid PROXY arguments")) + break + } + proxyL4Proto := proxyTokens[1] + switch proxyL4Proto { + case "TCP4": + fallthrough + case "TCP6": + proxyL3Src := proxyTokens[2] + server.log().Debugf("Updating client IP from %s to %s", client.RemoteIP, proxyL3Src) + client.RemoteIP = proxyL3Src + case "UNKNOWN": + default: + client.sendResponse(fmt.Sprintf("%d%s Invalid PROXY protocol: %s", response.ClassPermanentFailure, response.InvalidCommandArguments, proxyL4Proto)) + } + client.sendResponse("") + case strings.Index(cmd, "HELO") == 0: client.Helo = strings.Trim(input[4:], " ") client.resetTransaction() @@ -530,6 +555,7 @@ func (server *server) handleClient(client *client) { client.sendResponse(response.Canned.FailUnrecognizedCmd) } } + firstCmd = false case ClientData: diff --git a/tests/guerrilla_test.go b/tests/guerrilla_test.go index 2ae85b9a..e6b63a5f 100644 --- a/tests/guerrilla_test.go +++ b/tests/guerrilla_test.go @@ -526,6 +526,78 @@ func TestRFC2821LimitDomain(t *testing.T) { os.Truncate("./testlog", 0) } +// Test support for Proxy Protocol +func TestProxyProtocol(t *testing.T) { + if initErr != nil { + t.Error(initErr) + t.FailNow() + } + if startErrors := app.Start(); startErrors == nil { + conn, bufin, err := Connect(config.Servers[0], 20) + hostname := config.Servers[0].Hostname + if err != nil { + // handle error + t.Error(err.Error(), config.Servers[0].ListenInterface) + t.FailNow() + } else { + // Test PROXY header + response, err := Command(conn, bufin, "PROXY TCP4 1.1.1.1 2.2.2.2 12345 9876") + if err != nil { + t.Error("command failed", err.Error()) + } + expected := "" + if strings.Index(response, expected) != 0 { + t.Error("Server did not respond with", expected, ", it said:"+response) + } + // Reset + response, err = Command(conn, bufin, "RSET") + if err != nil { + t.Error("command failed", err.Error()) + } + expected = "250 2.1.0 OK" + if strings.Index(response, expected) != 0 { + t.Error("Server did not respond with", expected, ", it said:"+response) + } + // Start a new transaction + response, err = Command(conn, bufin, "HELO localtester") + if err != nil { + t.Error("command failed", err.Error()) + } + expected = fmt.Sprintf("250 %s Hello", hostname) + if strings.Index(response, expected) != 0 { + t.Error("Server did not respond with", expected, ", it said:"+response) + } + // Send the PROXY header, but not as the first header + response, err = Command(conn, bufin, "PROXY TCP4 1.1.1.1 2.2.2.2 12345 9876") + if err != nil { + t.Error("command failed", err.Error()) + } + expected = "5.5.1 PROXY must be the first command" + if strings.Index(response, expected) != 0 { + t.Error("Server did not respond with", expected, ", it said:"+response) + } + + // be kind, QUIT. And we are sure that bufin does not contain fragments from the EHLO command. + response, err = Command(conn, bufin, "QUIT") + if err != nil { + t.Error("command failed", err.Error()) + } + expected = "221 2.0.0 Bye" + if strings.Index(response, expected) != 0 { + t.Error("Server did not respond with", expected, ", it said:"+response) + } + } + conn.Close() + app.Shutdown() + } else { + if startErrors := app.Start(); startErrors != nil { + t.Error(startErrors) + app.Shutdown() + t.FailNow() + } + } +} + // Test several different inputs to MAIL FROM command func TestMailFromCmd(t *testing.T) { if initErr != nil {