diff --git a/Sources/Redis/RedisConfiguration.swift b/Sources/Redis/RedisConfiguration.swift index 7f76697..e590a20 100644 --- a/Sources/Redis/RedisConfiguration.swift +++ b/Sources/Redis/RedisConfiguration.swift @@ -1,4 +1,6 @@ @_exported import struct Foundation.URL +import NIOSSL +import NIOPosix @_exported import struct Logging.Logger @_exported import struct NIO.TimeAmount import enum NIO.SocketAddress @@ -11,6 +13,8 @@ public struct RedisConfiguration { public var password: String? public var database: Int? public var pool: PoolOptions + public var tlsConfiguration: TLSConfiguration? + public var tlsHostname: String? public struct PoolOptions { public var maximumConnectionCount: RedisConnectionPoolSize @@ -34,23 +38,32 @@ public struct RedisConfiguration { } } - public init(url string: String, pool: PoolOptions = .init()) throws { + public init(url string: String, tlsConfiguration: TLSConfiguration? = nil, pool: PoolOptions = .init()) throws { guard let url = URL(string: string) else { throw ValidationError.invalidURLString } - try self.init(url: url, pool: pool) + try self.init(url: url, tlsConfiguration: tlsConfiguration, pool: pool) } - public init(url: URL, pool: PoolOptions = .init()) throws { + public init(url: URL, tlsConfiguration: TLSConfiguration? = nil, pool: PoolOptions = .init()) throws { guard let scheme = url.scheme, !scheme.isEmpty else { throw ValidationError.missingURLScheme } - guard scheme == "redis" else { throw ValidationError.invalidURLScheme } + guard scheme == "redis" || scheme == "rediss" else { throw ValidationError.invalidURLScheme } guard let host = url.host, !host.isEmpty else { throw ValidationError.missingURLHost } + let defaultTLSConfig: TLSConfiguration? + if scheme == "rediss" { + // If we're given a 'rediss' URL, make sure we have at least a default TLS config. + defaultTLSConfig = tlsConfiguration ?? .makeClientConfiguration() + } else { + defaultTLSConfig = tlsConfiguration + } + try self.init( hostname: host, port: url.port ?? RedisConnection.Configuration.defaultPort, password: url.password, + tlsConfiguration: defaultTLSConfig, database: Int(url.lastPathComponent), pool: pool ) @@ -60,6 +73,7 @@ public struct RedisConfiguration { hostname: String, port: Int = RedisConnection.Configuration.defaultPort, password: String? = nil, + tlsConfiguration: TLSConfiguration? = nil, database: Int? = nil, pool: PoolOptions = .init() ) throws { @@ -68,6 +82,8 @@ public struct RedisConfiguration { try self.init( serverAddresses: [.makeAddressResolvingHost(hostname, port: port)], password: password, + tlsConfiguration: tlsConfiguration, + tlsHostname: hostname, database: database, pool: pool ) @@ -76,18 +92,22 @@ public struct RedisConfiguration { public init( serverAddresses: [SocketAddress], password: String? = nil, + tlsConfiguration: TLSConfiguration? = nil, + tlsHostname: String? = nil, database: Int? = nil, pool: PoolOptions = .init() ) throws { self.serverAddresses = serverAddresses self.password = password + self.tlsConfiguration = tlsConfiguration + self.tlsHostname = tlsHostname self.database = database self.pool = pool } } extension RedisConnectionPool.Configuration { - internal init(_ config: RedisConfiguration, defaultLogger: Logger) { + internal init(_ config: RedisConfiguration, defaultLogger: Logger, customClient: ClientBootstrap?) { self.init( initialServerConnectionAddresses: config.serverAddresses, maximumConnectionCount: config.pool.maximumConnectionCount, @@ -95,7 +115,7 @@ extension RedisConnectionPool.Configuration { connectionInitialDatabase: config.database, connectionPassword: config.password, connectionDefaultLogger: defaultLogger, - tcpClient: nil + tcpClient: customClient ), minimumConnectionCount: config.pool.minimumConnectionCount, connectionBackoffFactor: config.pool.connectionBackoffFactor, diff --git a/Sources/Redis/RedisStorage.swift b/Sources/Redis/RedisStorage.swift index 2c28183..0f141b7 100644 --- a/Sources/Redis/RedisStorage.swift +++ b/Sources/Redis/RedisStorage.swift @@ -1,5 +1,8 @@ import Vapor import NIOConcurrencyHelpers +import NIOCore +import NIOPosix +import NIOSSL extension Application { private struct RedisStorageKey: StorageKey { @@ -75,8 +78,28 @@ extension RedisStorage { let newKey: PoolKey = PoolKey(eventLoopKey: eventLoop.key, redisID: redisID) + let redisTLSClient: ClientBootstrap? = { + guard let tlsConfig = configuration.tlsConfiguration, + let tlsHost = configuration.tlsHostname else { return nil } + + return ClientBootstrap(group: eventLoop) + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .channelInitializer { channel in + do { + let sslContext = try NIOSSLContext(configuration: tlsConfig) + return EventLoopFuture.andAllSucceed([ + channel.pipeline.addHandler(try NIOSSLClientHandler(context: sslContext, + serverHostname: tlsHost)), + channel.pipeline.addBaseRedisHandlers() + ], on: channel.eventLoop) + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + } + }() + let newPool = RedisConnectionPool( - configuration: .init(configuration, defaultLogger: application.logger), + configuration: .init(configuration, defaultLogger: application.logger, customClient: redisTLSClient), boundEventLoop: eventLoop) newPools[newKey] = newPool