Skip to content

Commit

Permalink
okhttp: Add support for file-based private keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ejona86 committed Jul 1, 2022
1 parent bc50adf commit 2cb2fe5
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 42 deletions.
1 change: 1 addition & 0 deletions okhttp/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ java_library(
deps = [
"//api",
"//core:internal",
"//core:util",
"@com_google_code_findbugs_jsr305//jar",
"@com_google_errorprone_error_prone_annotations//jar",
"@com_google_guava_guava//jar",
Expand Down
58 changes: 52 additions & 6 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@
import io.grpc.okhttp.internal.ConnectionSpec;
import io.grpc.okhttp.internal.Platform;
import io.grpc.okhttp.internal.TlsVersion;
import io.grpc.util.CertificateUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.EnumSet;
import java.util.Set;
Expand All @@ -73,6 +74,7 @@
import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
Expand Down Expand Up @@ -597,7 +599,16 @@ static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) {
if (tlsCreds.getKeyManagers() != null) {
km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]);
} else if (tlsCreds.getPrivateKey() != null) {
return SslSocketFactoryResult.error("byte[]-based private key unsupported. Use KeyManager");
if (tlsCreds.getPrivateKeyPassword() != null) {
return SslSocketFactoryResult.error("byte[]-based private key with password unsupported. "
+ "Use unencrypted file or KeyManager");
}
try {
km = createKeyManager(tlsCreds.getCertificateChain(), tlsCreds.getPrivateKey());
} catch (GeneralSecurityException gse) {
log.log(Level.FINE, "Exception loading private key from credential", gse);
return SslSocketFactoryResult.error("Unable to load private key: " + gse.getMessage());
}
} // else don't have a client cert
TrustManager[] tm = null;
if (tlsCreds.getTrustManagers() != null) {
Expand Down Expand Up @@ -652,6 +663,39 @@ static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) {
}
}

static KeyManager[] createKeyManager(byte[] certChain, byte[] privateKey)
throws GeneralSecurityException {
X509Certificate[] chain;
ByteArrayInputStream inCertChain = new ByteArrayInputStream(certChain);
try {
chain = CertificateUtils.getX509Certificates(inCertChain);
} finally {
GrpcUtil.closeQuietly(inCertChain);
}
PrivateKey key;
ByteArrayInputStream inPrivateKey = new ByteArrayInputStream(privateKey);
try {
key = CertificateUtils.getPrivateKey(inPrivateKey);
} catch (IOException uee) {
throw new GeneralSecurityException("Unable to decode private key", uee);
} finally {
GrpcUtil.closeQuietly(inPrivateKey);
}
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
ks.setKeyEntry("key", key, new char[0], chain);

KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(ks, new char[0]);
return keyManagerFactory.getKeyManagers();
}

static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
Expand All @@ -660,15 +704,17 @@ static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurit
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
CertificateFactory cf = CertificateFactory.getInstance("X.509");
X509Certificate[] certs;
ByteArrayInputStream in = new ByteArrayInputStream(rootCerts);
try {
X509Certificate cert = (X509Certificate) cf.generateCertificate(in);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
certs = CertificateUtils.getX509Certificates(in);
} finally {
GrpcUtil.closeQuietly(in);
}
for (X509Certificate cert : certs) {
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
}

TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
Expand Down
47 changes: 13 additions & 34 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,9 @@
import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.internal.Platform;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
Expand All @@ -57,8 +52,6 @@
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;

/**
* Build servers with the OkHttp transport.
Expand Down Expand Up @@ -287,15 +280,25 @@ static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentia
if (tlsCreds.getKeyManagers() != null) {
km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]);
} else if (tlsCreds.getPrivateKey() != null) {
return HandshakerSocketFactoryResult.error(
"byte[]-based private key unsupported. Use KeyManager");
if (tlsCreds.getPrivateKeyPassword() != null) {
return HandshakerSocketFactoryResult.error("byte[]-based private key with password "
+ "unsupported. Use unencrypted file or KeyManager");
}
try {
km = OkHttpChannelBuilder.createKeyManager(
tlsCreds.getCertificateChain(), tlsCreds.getPrivateKey());
} catch (GeneralSecurityException gse) {
log.log(Level.FINE, "Exception loading private key from credential", gse);
return HandshakerSocketFactoryResult.error(
"Unable to load private key: " + gse.getMessage());
}
} // else don't have a client cert
TrustManager[] tm = null;
if (tlsCreds.getTrustManagers() != null) {
tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]);
} else if (tlsCreds.getRootCertificates() != null) {
try {
tm = createTrustManager(tlsCreds.getRootCertificates());
tm = OkHttpChannelBuilder.createTrustManager(tlsCreds.getRootCertificates());
} catch (GeneralSecurityException gse) {
log.log(Level.FINE, "Exception loading root certificates from credential", gse);
return HandshakerSocketFactoryResult.error(
Expand Down Expand Up @@ -341,30 +344,6 @@ static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentia
}
}

static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
CertificateFactory cf = CertificateFactory.getInstance("X.509");
ByteArrayInputStream in = new ByteArrayInputStream(rootCerts);
try {
X509Certificate cert = (X509Certificate) cf.generateCertificate(in);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
} finally {
GrpcUtil.closeQuietly(in);
}

TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}

static final class HandshakerSocketFactoryResult {
public final HandshakerSocketFactory factory;
public final String error;
Expand Down
57 changes: 55 additions & 2 deletions okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,62 @@ public void sslSocketFactoryFrom_tls_mtls() throws Exception {
}

@Test
public void sslSocketFactoryFrom_tls_mtls_byteKeyUnsupported() throws Exception {
public void sslSocketFactoryFrom_tls_mtls_keyFile() throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null);
keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
keyManagerFactory.init(keyStore, new char[0]);

KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType());
certStore.load(null);
certStore.setCertificateEntry("mycert", cert.cert());
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(certStore);

SSLContext serverContext = SSLContext.getInstance("TLS");
serverContext.init(
keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);
final SSLServerSocket serverListenSocket =
(SSLServerSocket) serverContext.getServerSocketFactory().createServerSocket(0);
serverListenSocket.setNeedClientAuth(true);
final SettableFuture<SSLSocket> serverSocket = SettableFuture.create();
new Thread(new Runnable() {
@Override public void run() {
try {
SSLSocket socket = (SSLSocket) serverListenSocket.accept();
socket.getSession(); // Force handshake
serverSocket.set(socket);
serverListenSocket.close();
} catch (Throwable t) {
serverSocket.setException(t);
}
}
}).start();

ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.keyManager(cert.certificate(), cert.privateKey())
.trustManager(cert.certificate())
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
SSLSocket socket =
(SSLSocket) result.factory.createSocket("localhost", serverListenSocket.getLocalPort());
socket.getSession(); // Force handshake
assertThat(((X500Principal) serverSocket.get().getSession().getPeerPrincipal()).getName())
.isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST);
socket.close();
serverSocket.get().close();
}

@Test
public void sslSocketFactoryFrom_tls_mtls_passwordUnsupported() throws Exception {
ChannelCredentials creds = TlsChannelCredentials.newBuilder()
.keyManager(TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
.keyManager(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"), "password")
.build();
OkHttpChannelBuilder.SslSocketFactoryResult result =
OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
Expand Down

0 comments on commit 2cb2fe5

Please sign in to comment.