Skip to content

Commit

Permalink
Add support for TLS-backed connections to Redis (#205)
Browse files Browse the repository at this point in the history
* Allow TLS for Redis connections

* Tidy up TLS config logic
  • Loading branch information
iKenndac authored Feb 1, 2023
1 parent 9cf334a commit fee95ab
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
32 changes: 26 additions & 6 deletions Sources/Redis/RedisConfiguration.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@_exported import struct Foundation.URL
import NIOSSL
import NIOPosix
@_exported import struct Logging.Logger
@_exported import struct NIO.TimeAmount
import enum NIO.SocketAddress
Expand All @@ -11,6 +13,8 @@ public struct RedisConfiguration {
public var password: String?
public var database: Int?
public var pool: PoolOptions
public var tlsConfiguration: TLSConfiguration?
public var tlsHostname: String?

public struct PoolOptions {
public var maximumConnectionCount: RedisConnectionPoolSize
Expand All @@ -34,23 +38,32 @@ public struct RedisConfiguration {
}
}

public init(url string: String, pool: PoolOptions = .init()) throws {
public init(url string: String, tlsConfiguration: TLSConfiguration? = nil, pool: PoolOptions = .init()) throws {
guard let url = URL(string: string) else { throw ValidationError.invalidURLString }
try self.init(url: url, pool: pool)
try self.init(url: url, tlsConfiguration: tlsConfiguration, pool: pool)
}

public init(url: URL, pool: PoolOptions = .init()) throws {
public init(url: URL, tlsConfiguration: TLSConfiguration? = nil, pool: PoolOptions = .init()) throws {
guard
let scheme = url.scheme,
!scheme.isEmpty
else { throw ValidationError.missingURLScheme }
guard scheme == "redis" else { throw ValidationError.invalidURLScheme }
guard scheme == "redis" || scheme == "rediss" else { throw ValidationError.invalidURLScheme }
guard let host = url.host, !host.isEmpty else { throw ValidationError.missingURLHost }

let defaultTLSConfig: TLSConfiguration?
if scheme == "rediss" {
// If we're given a 'rediss' URL, make sure we have at least a default TLS config.
defaultTLSConfig = tlsConfiguration ?? .makeClientConfiguration()
} else {
defaultTLSConfig = tlsConfiguration
}

try self.init(
hostname: host,
port: url.port ?? RedisConnection.Configuration.defaultPort,
password: url.password,
tlsConfiguration: defaultTLSConfig,
database: Int(url.lastPathComponent),
pool: pool
)
Expand All @@ -60,6 +73,7 @@ public struct RedisConfiguration {
hostname: String,
port: Int = RedisConnection.Configuration.defaultPort,
password: String? = nil,
tlsConfiguration: TLSConfiguration? = nil,
database: Int? = nil,
pool: PoolOptions = .init()
) throws {
Expand All @@ -68,6 +82,8 @@ public struct RedisConfiguration {
try self.init(
serverAddresses: [.makeAddressResolvingHost(hostname, port: port)],
password: password,
tlsConfiguration: tlsConfiguration,
tlsHostname: hostname,
database: database,
pool: pool
)
Expand All @@ -76,26 +92,30 @@ public struct RedisConfiguration {
public init(
serverAddresses: [SocketAddress],
password: String? = nil,
tlsConfiguration: TLSConfiguration? = nil,
tlsHostname: String? = nil,
database: Int? = nil,
pool: PoolOptions = .init()
) throws {
self.serverAddresses = serverAddresses
self.password = password
self.tlsConfiguration = tlsConfiguration
self.tlsHostname = tlsHostname
self.database = database
self.pool = pool
}
}

extension RedisConnectionPool.Configuration {
internal init(_ config: RedisConfiguration, defaultLogger: Logger) {
internal init(_ config: RedisConfiguration, defaultLogger: Logger, customClient: ClientBootstrap?) {
self.init(
initialServerConnectionAddresses: config.serverAddresses,
maximumConnectionCount: config.pool.maximumConnectionCount,
connectionFactoryConfiguration: .init(
connectionInitialDatabase: config.database,
connectionPassword: config.password,
connectionDefaultLogger: defaultLogger,
tcpClient: nil
tcpClient: customClient
),
minimumConnectionCount: config.pool.minimumConnectionCount,
connectionBackoffFactor: config.pool.connectionBackoffFactor,
Expand Down
25 changes: 24 additions & 1 deletion Sources/Redis/RedisStorage.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import Vapor
import NIOConcurrencyHelpers
import NIOCore
import NIOPosix
import NIOSSL

extension Application {
private struct RedisStorageKey: StorageKey {
Expand Down Expand Up @@ -75,8 +78,28 @@ extension RedisStorage {

let newKey: PoolKey = PoolKey(eventLoopKey: eventLoop.key, redisID: redisID)

let redisTLSClient: ClientBootstrap? = {
guard let tlsConfig = configuration.tlsConfiguration,
let tlsHost = configuration.tlsHostname else { return nil }

return ClientBootstrap(group: eventLoop)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelInitializer { channel in
do {
let sslContext = try NIOSSLContext(configuration: tlsConfig)
return EventLoopFuture.andAllSucceed([
channel.pipeline.addHandler(try NIOSSLClientHandler(context: sslContext,
serverHostname: tlsHost)),
channel.pipeline.addBaseRedisHandlers()
], on: channel.eventLoop)
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
}
}()

let newPool = RedisConnectionPool(
configuration: .init(configuration, defaultLogger: application.logger),
configuration: .init(configuration, defaultLogger: application.logger, customClient: redisTLSClient),
boundEventLoop: eventLoop)

newPools[newKey] = newPool
Expand Down

0 comments on commit fee95ab

Please sign in to comment.