diff --git a/src/main/java/core/packetproxy/ProxySSLTransparent.java b/src/main/java/core/packetproxy/ProxySSLTransparent.java index 31e77e1..15108f1 100644 --- a/src/main/java/core/packetproxy/ProxySSLTransparent.java +++ b/src/main/java/core/packetproxy/ProxySSLTransparent.java @@ -138,7 +138,7 @@ private void checkTransparentSSLProxy(Socket client, int proxyPort) throws Excep WrapEndpoint wep_e = new WrapEndpoint(client_e, ArrayUtils.subarray(buff, 0, length)); InetSocketAddress serverAddr = new InetSocketAddress(serverName, proxyPort); Server server = Servers.getInstance().queryByAddress(serverAddr); - SSLSocketEndpoint server_e = new SSLSocketEndpoint(serverAddr, serverName); + SSLSocketEndpoint server_e = new SSLSocketEndpoint(serverAddr, serverName, null); createConnection(wep_e, server_e, server); } else { diff --git a/src/main/java/core/packetproxy/common/EndpointFactory.java b/src/main/java/core/packetproxy/common/EndpointFactory.java index 86a6ec4..5c688e9 100644 --- a/src/main/java/core/packetproxy/common/EndpointFactory.java +++ b/src/main/java/core/packetproxy/common/EndpointFactory.java @@ -22,7 +22,6 @@ import javax.net.ssl.SSLSocket; -import packetproxy.http.Http; import packetproxy.http.Https; import packetproxy.model.OneShotPacket; import packetproxy.model.Server; @@ -61,7 +60,7 @@ public static Endpoint createFromURI(String uri) throws Exception { String host = u.getHost(); int port = u.getPort() > 0 ? u.getPort() : 80; if (u.getScheme().equalsIgnoreCase("https")) { - return new SSLSocketEndpoint(new InetSocketAddress(host, port), host); + return new SSLSocketEndpoint(new InetSocketAddress(host, port), host, null); } else if (u.getScheme().equalsIgnoreCase("http")) { return new SocketEndpoint(new InetSocketAddress(host, port)); } else { @@ -71,24 +70,16 @@ public static Endpoint createFromURI(String uri) throws Exception { public static Endpoint createFromOneShotPacket(OneShotPacket packet) throws Exception { if (packet.getUseSSL()) { - return new SSLSocketEndpoint(packet.getServer(), packet.getServerName()); + return new SSLSocketEndpoint(packet.getServer(), packet.getServerName(), packet.getAlpn()); } else { // nc など複数同時接続を受け付けないconnection用に10秒でtimeoutする return new SocketEndpoint(packet.getServer(), 10 * 1000); } } - public static Endpoint createServerEndpointFromHttp(Http http) throws Exception { - if (http.isProxySsl()) { - return new SSLSocketEndpoint(http.getServerAddr(), http.getServerName()); - } else { - return new SocketEndpoint(http.getServerAddr()); - } - } - public static Endpoint createFromServer(Server server) throws Exception { if (server.getUseSSL()) { - return new SSLSocketEndpoint(server.getAddress(), server.getIp()); + return new SSLSocketEndpoint(server.getAddress(), server.getIp(), null); } else { return new SocketEndpoint(server.getAddress()); } diff --git a/src/main/java/core/packetproxy/common/SSLSocketEndpoint.java b/src/main/java/core/packetproxy/common/SSLSocketEndpoint.java index 0e8e8ea..afb3a18 100644 --- a/src/main/java/core/packetproxy/common/SSLSocketEndpoint.java +++ b/src/main/java/core/packetproxy/common/SSLSocketEndpoint.java @@ -27,20 +27,24 @@ public class SSLSocketEndpoint implements Endpoint { protected SSLSocket socket; protected String server_name; + protected String alpn; public SSLSocketEndpoint(SSLSocketEndpoint ep) { this.server_name = ep.server_name; this.socket = ep.socket; + this.alpn = ep.alpn; } public SSLSocketEndpoint(SSLSocket socket, String SNIServerName) { this.server_name = SNIServerName; this.socket = socket; + this.alpn = socket.getApplicationProtocol(); } - public SSLSocketEndpoint(InetSocketAddress addr, String SNIServerName) throws Exception { + public SSLSocketEndpoint(InetSocketAddress addr, String SNIServerName, String alpn) throws Exception { this.server_name = SNIServerName; - this.socket = Https.createClientSSLSocket(addr, SNIServerName); + this.alpn = alpn; + this.socket = Https.createClientSSLSocket(addr, SNIServerName, alpn); } @Override diff --git a/src/main/java/core/packetproxy/http/Https.java b/src/main/java/core/packetproxy/http/Https.java index f6de2f0..f008c0e 100644 --- a/src/main/java/core/packetproxy/http/Https.java +++ b/src/main/java/core/packetproxy/http/Https.java @@ -212,29 +212,39 @@ public PrivateKey getPrivateKey(String s) { } }}; - public static SSLSocket convertToClientSSLSocket(Socket socket) throws Exception { + public static SSLSocket convertToClientSSLSocket(Socket socket, String alpn) throws Exception { SSLSocketFactory ssf = createSSLSocketFactory(); SSLSocket sock = (SSLSocket) ssf.createSocket(socket, null, socket.getPort(), false); SSLParameters sslp = sock.getSSLParameters(); - String[] clientAPs ={ "h2", "http/1.1", "http/1.0" }; + String[] clientAPs; + if (alpn != null && alpn.length() > 0) { + clientAPs = new String[]{ alpn }; + } else { + clientAPs = new String[]{ "h2", "http/1.1", "http/1.0" }; + } sslp.setApplicationProtocols(clientAPs); sock.setSSLParameters(sslp); sock.startHandshake(); return sock; } - public static SSLSocket createClientSSLSocket(InetSocketAddress addr) throws Exception { + public static SSLSocket createClientSSLSocket(InetSocketAddress addr, String alpn) throws Exception { SSLSocketFactory ssf = createSSLSocketFactory(); SSLSocket sock = (SSLSocket) ssf.createSocket(addr.getAddress(), addr.getPort()); SSLParameters sslp = sock.getSSLParameters(); - String[] clientAPs ={ "h2", "http/1.1", "http/1.0" }; + String[] clientAPs; + if (alpn != null && alpn.length() > 0) { + clientAPs = new String[]{ alpn }; + } else { + clientAPs = new String[]{ "h2", "http/1.1", "http/1.0" }; + } sslp.setApplicationProtocols(clientAPs); sock.setSSLParameters(sslp); sock.startHandshake(); return sock; } - public static SSLSocket createClientSSLSocket(InetSocketAddress addr, String SNIServerName) throws Exception { + public static SSLSocket createClientSSLSocket(InetSocketAddress addr, String SNIServerName, String alpn) throws Exception { /* SNI */ SNIHostName serverName = new SNIHostName(SNIServerName); /* Fetch Client Certificate from ClientKeyManager */ @@ -244,7 +254,12 @@ public static SSLSocket createClientSSLSocket(InetSocketAddress addr, String SNI SSLSocketFactory ssf = createSSLSocketFactory(); SSLSocket sock = (SSLSocket) ssf.createSocket(addr.getAddress(), addr.getPort()); SSLParameters sslp = sock.getSSLParameters(); - String[] clientAPs ={ "h2", "http/1.1", "http/1.0" }; + String[] clientAPs; + if (alpn != null && alpn.length() > 0) { + clientAPs = new String[]{ alpn }; + } else { + clientAPs = new String[]{ "h2", "http/1.1", "http/1.0" }; + } sslp.setApplicationProtocols(clientAPs); sock.setSSLParameters(sslp);