diff --git a/Sources/NIOHTTP1/HTTPHeaderValidator.swift b/Sources/NIOHTTP1/HTTPHeaderValidator.swift index 11eeb8b304..5f5e20d938 100644 --- a/Sources/NIOHTTP1/HTTPHeaderValidator.swift +++ b/Sources/NIOHTTP1/HTTPHeaderValidator.swift @@ -71,9 +71,16 @@ public final class NIOHTTPResponseHeadersValidator: ChannelOutboundHandler, Remo } private var state: State + private let sendResponseOnInvalidHeader: Bool public init() { self.state = .validating + self.sendResponseOnInvalidHeader = false + } + + public init(pipelineConfiguration: ChannelPipeline.SynchronousOperations.Configuration) { + self.state = .validating + self.sendResponseOnInvalidHeader = pipelineConfiguration.headerValidationResponse } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { @@ -82,6 +89,14 @@ public final class NIOHTTPResponseHeadersValidator: ChannelOutboundHandler, Remo if head.headers.areValidToSend { context.write(data, promise: promise) } else { + // We won't write another header since we drop them going forward to write + // out a response if configured to do so + if self.sendResponseOnInvalidHeader { + let headers = HTTPHeaders([("Connection", "close"), ("Content-Length", "0")]) + let head = HTTPResponseHead(version: .http1_1, status: .internalServerError, headers: headers) + context.write(Self.wrapOutboundOut(.head(head)), promise: nil) + context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) + } self.state = .dropping promise?.fail(HTTPParserError.invalidHeaderToken) context.fireErrorCaught(HTTPParserError.invalidHeaderToken) diff --git a/Sources/NIOHTTP1/HTTPPipelineSetup.swift b/Sources/NIOHTTP1/HTTPPipelineSetup.swift index a7645377d6..0fdc45a2a2 100644 --- a/Sources/NIOHTTP1/HTTPPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPPipelineSetup.swift @@ -913,13 +913,67 @@ extension ChannelPipeline.SynchronousOperations { ) } + /// Configure a `ChannelPipeline` for use as a HTTP server. + /// + /// This function knows how to set up all first-party HTTP channel handlers appropriately + /// for server use. It supports the following features: + /// + /// 1. Providing assistance handling clients that pipeline HTTP requests, using the + /// `HTTPServerPipelineHandler`. + /// 2. Supporting HTTP upgrade, using the `HTTPServerUpgradeHandler`. + /// 3. Providing assistance handling protocol errors. + /// 4. Validating outbound header fields to protect against response splitting attacks. + /// 5. Specifying whether the header validation should return a response + /// + /// This method will likely be extended in future with more support for other first-party + /// features. + /// + /// - important: This **must** be called on the Channel's event loop. + /// - Parameters: + /// - position: Where in the pipeline to add the HTTP server handlers, defaults to `.last`. + /// - pipelining: Whether to provide assistance handling HTTP clients that pipeline + /// their requests. Defaults to `true`. If `false`, users will need to handle + /// clients that pipeline themselves. + /// - upgrade: Whether to add a `HTTPServerUpgradeHandler` to the pipeline, configured for + /// HTTP upgrade. Defaults to `nil`, which will not add the handler to the pipeline. If + /// provided should be a tuple of an array of `HTTPServerProtocolUpgrader` and the upgrade + /// completion handler. See the documentation on `HTTPServerUpgradeHandler` for more + /// details. + /// - errorHandling: Whether to provide assistance handling protocol errors (e.g. + /// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`. + /// - headerValidation: Whether to validate outbound request headers to confirm that they meet + /// spec compliance. Defaults to `true`. + /// - encoderConfiguration: The configuration for the ``HTTPRequestEncoder``. + /// - configuration: Confguration for setting up for the pipeline. Provides additional options + /// for configuring the pipeline. + /// - Throws: If the pipeline could not be configured. + public func configureHTTPServerPipeline( + position: ChannelPipeline.SynchronousOperations.Position = .last, + withPipeliningAssistance pipelining: Bool = true, + withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, + withErrorHandling errorHandling: Bool = true, + withOutboundHeaderValidation headerValidation: Bool = true, + withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init(), + withConfiguration configuration: Configuration + ) throws { + try self._configureHTTPServerPipeline( + position: position, + withPipeliningAssistance: pipelining, + withServerUpgrade: upgrade, + withErrorHandling: errorHandling, + withOutboundHeaderValidation: headerValidation, + configuration: configuration + ) + } + private func _configureHTTPServerPipeline( position: ChannelPipeline.SynchronousOperations.Position = .last, withPipeliningAssistance pipelining: Bool = true, withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, withErrorHandling errorHandling: Bool = true, withOutboundHeaderValidation headerValidation: Bool = true, - withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init() + withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init(), + configuration: Configuration = .init(), ) throws { self.eventLoop.assertInEventLoop() @@ -933,7 +987,7 @@ extension ChannelPipeline.SynchronousOperations { } if headerValidation { - handlers.append(NIOHTTPResponseHeadersValidator()) + handlers.append(NIOHTTPResponseHeadersValidator(pipelineConfiguration: configuration)) } if errorHandling { @@ -952,4 +1006,14 @@ extension ChannelPipeline.SynchronousOperations { try self.addHandlers(handlers, position: position) } + + /// Configuration for setting up an HTTP client pipeline. + public struct Configuration { + /// Whether or not a response is returned when the header validation fails. + public var headerValidationResponse: Bool + + public init() { + self.headerValidationResponse = false + } + } } diff --git a/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift b/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift index 346a64a859..4d0830bff3 100644 --- a/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift @@ -635,6 +635,61 @@ final class HTTPHeaderValidationTests: XCTestCase { XCTAssertEqual(maybeReceivedHeadBytes, toleratedRequestBytes) XCTAssertEqual(maybeReceivedTrailerBytes, toleratedTrailerBytes) } + + func testBadRequestResponseIsReturnedIfHeadersInvalidAndConfiguredToDoSo() throws { + let channel = EmbeddedChannel() + var pipelineConfig = ChannelPipeline.SynchronousOperations.Configuration() + pipelineConfig.headerValidationResponse = true + try channel.pipeline.syncOperations.configureHTTPServerPipeline(withConfiguration: pipelineConfig) + try channel.primeForResponse() + + func assertReadHead(from channel: EmbeddedChannel) throws { + if case .head = try channel.readInbound(as: HTTPServerRequestPart.self) { + () + } else { + XCTFail("Expected 'head'") + } + } + + func assertReadEnd(from channel: EmbeddedChannel) throws { + if case .end = try channel.readInbound(as: HTTPServerRequestPart.self) { + () + } else { + XCTFail("Expected 'end'") + } + } + + // Read the first request. + try assertReadHead(from: channel) + try assertReadEnd(from: channel) + XCTAssertNil(try channel.readInbound(as: HTTPServerRequestPart.self)) + + // Respond with bad headers; they should cause an error and result in the rest of the + // response being dropped, but a fallback response being sent + let head = HTTPResponseHead(version: .http1_1, status: .ok, headers: [":pseudo-header": "not-here"]) + XCTAssertThrowsError(try channel.writeOutbound(HTTPServerResponsePart.head(head))) + + // We expect exactly one ByteBuffer in the output. + guard var written = try channel.readOutbound(as: ByteBuffer.self) else { + XCTFail("No writes") + return + } + + XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound())) + + // Check the response. + assertResponseIs( + response: written.readString(length: written.readableBytes)!, + expectedResponseLine: "HTTP/1.1 500 Internal Server Error", + expectedResponseHeaders: ["Connection: close", "Content-Length: 0"] + ) + XCTAssertThrowsError(try channel.writeOutbound(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer())))) + XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self)) + XCTAssertThrowsError(try channel.writeOutbound(HTTPServerResponsePart.end(nil))) + XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self)) + + _ = try? channel.finish() + } } extension EmbeddedChannel {