Skip to content

Commit

Permalink
Fix EventStream to properly handle utf-8 (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
waahm7 authored Jan 9, 2024
1 parent e9aca22 commit 4aebaa6
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 80 deletions.
2 changes: 1 addition & 1 deletion Source/AwsCommonRuntimeKit/crt/AWSString.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ final class AWSString {
let rawValue: UnsafeMutablePointer<aws_string>

init(_ str: String) {
self.rawValue = aws_string_new_from_array(allocator.rawValue, str, str.count)
self.rawValue = aws_string_new_from_array(allocator.rawValue, str, str.utf8.count)
}

var count: Int {
Expand Down
5 changes: 4 additions & 1 deletion Source/AwsCommonRuntimeKit/crt/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ extension String {
func withByteCursor<Result>(_ body: (aws_byte_cursor) -> Result
) -> Result {
return self.withCString { arg1C in
return body(aws_byte_cursor_from_c_str(arg1C))
return body(aws_byte_cursor_from_array(arg1C, self.utf8.count))
}
}

Expand Down Expand Up @@ -94,6 +94,9 @@ extension aws_byte_buf {
}

func toData() -> Data {
if self.len == 0 {
return Data()
}
return Data(bytes: self.buffer, count: self.len)
}
}
Expand Down
152 changes: 76 additions & 76 deletions Source/AwsCommonRuntimeKit/event-stream/EventStreamMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,90 +47,90 @@ public struct EventStreamMessage {

extension EventStreamMessage {
func addHeader(header: EventStreamHeader, rawHeaders: UnsafeMutablePointer<aws_array_list>) throws {
if header.name.count > EventStreamHeader.maxNameLength {
let headerNameLength = header.name.utf8.count
if headerNameLength > EventStreamHeader.maxNameLength {
throw CommonRunTimeError.crtError(
.init(
code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue))
}
let addCHeader: () throws -> Int32 = {
return try header.name.withCString { headerName in
switch header.value {
case .bool(let value):
return aws_event_stream_add_bool_header(
rawHeaders,
headerName,
UInt8(header.name.count),
Int8(value.uintValue))
case .byte(let value):
return aws_event_stream_add_byte_header(
rawHeaders,
headerName,
UInt8(header.name.count),
value)
case .int16(let value):
return aws_event_stream_add_int16_header(
rawHeaders,
headerName,
UInt8(header.name.count),
value)
case .int32(let value):
return aws_event_stream_add_int32_header(
let headerNameLength = UInt8(headerNameLength)
switch header.value {
case .bool(let value):
return aws_event_stream_add_bool_header(
rawHeaders,
header.name,
headerNameLength,
Int8(value.uintValue))
case .byte(let value):
return aws_event_stream_add_byte_header(
rawHeaders,
header.name,
headerNameLength,
value)
case .int16(let value):
return aws_event_stream_add_int16_header(
rawHeaders,
header.name,
headerNameLength,
value)
case .int32(let value):
return aws_event_stream_add_int32_header(
rawHeaders,
header.name,
headerNameLength,
value)
case .int64(let value):
return aws_event_stream_add_int64_header(
rawHeaders,
header.name,
headerNameLength,
value)
case .byteBuf(var value):
let valueCount = value.count
if valueCount > EventStreamHeader.maxValueLength {
throw CommonRunTimeError.crtError(
.init(
code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue))
}
return value.withUnsafeMutableBytes {
let bytes = $0.bindMemory(to: UInt8.self).baseAddress!
return aws_event_stream_add_bytebuf_header(
rawHeaders,
headerName,
UInt8(header.name.count),
value)
case .int64(let value):
return aws_event_stream_add_int64_header(
header.name,
headerNameLength,
bytes,
UInt16(valueCount),
1)
}
case .string(let value):
let valueCount = value.utf8.count
if valueCount > EventStreamHeader.maxValueLength {
throw CommonRunTimeError.crtError(
.init(
code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue))
}
return aws_event_stream_add_string_header(
rawHeaders,
headerName,
UInt8(header.name.count),
value)
case .byteBuf(var value):
if value.count > EventStreamHeader.maxValueLength {
throw CommonRunTimeError.crtError(
.init(
code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue))
}
return value.withUnsafeMutableBytes {
let bytes = $0.bindMemory(to: UInt8.self).baseAddress!
return aws_event_stream_add_bytebuf_header(
rawHeaders,
headerName,
UInt8(header.name.count),
bytes,
UInt16($0.count),
1)
}
case .string(let value):
if value.count > EventStreamHeader.maxValueLength {
throw CommonRunTimeError.crtError(
.init(
code: AWS_ERROR_EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN.rawValue))
}
return value.withCString {
aws_event_stream_add_string_header(
rawHeaders,
headerName,
UInt8(header.name.count),
$0,
UInt16(value.count),
1)
}
case .timestamp(let value):
return aws_event_stream_add_timestamp_header(
header.name,
headerNameLength,
value,
UInt16(valueCount),
1)
case .timestamp(let value):
return aws_event_stream_add_timestamp_header(
rawHeaders,
header.name,
headerNameLength,
Int64(value.millisecondsSince1970))
case .uuid(let value):
return withUnsafeBytes(of: value) {
let address = $0.baseAddress?.assumingMemoryBound(to: UInt8.self)
return aws_event_stream_add_uuid_header(
rawHeaders,
headerName,
UInt8(header.name.count),
Int64(value.millisecondsSince1970))
case .uuid(let value):
return withUnsafeBytes(of: value) {
let address = $0.baseAddress?.assumingMemoryBound(to: UInt8.self)
return aws_event_stream_add_uuid_header(
rawHeaders,
headerName,
UInt8(header.name.count),
address)
}
header.name,
headerNameLength,
address)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ class EventStreamTests: XCBaseTestCase {
EventStreamHeader(name: "int32", value: .int32(value: 32)),
EventStreamHeader(name: "int64", value: .int32(value: 64)),
EventStreamHeader(name: "byteBuf", value: .byteBuf(value: "data".data(using: .utf8)!)),
EventStreamHeader(name: "emptyByteBuf", value: .byteBuf(value: Data())),
EventStreamHeader(name: "host", value: .string(value: "aws-crt-test-stuff.s3.amazonaws.com")),
EventStreamHeader(name: "host", value: .string(value: "aws-crt-test-stuff.s3.amazonaws.com")),
EventStreamHeader(name: "headerWithUtf8Character🧐", value: .string(value: "testValueWithEmoji🤯")),
EventStreamHeader(name: "bool", value: .bool(value: false)),
EventStreamHeader(name: "timestamp", value: .timestamp(value: Date(timeIntervalSinceNow: 10))),
EventStreamHeader(name: "uuid", value: .uuid(value: UUID(uuidString: "63318232-1C63-4D04-9A0C-6907F347704E")!)),
Expand All @@ -32,8 +34,8 @@ class EventStreamTests: XCBaseTestCase {
XCTFail("OnPayload callback is triggered unexpectedly.")
},
onPreludeReceived: { totalLength, headersLength in
XCTAssertEqual(totalLength, 210)
XCTAssertEqual(headersLength, 194)
XCTAssertEqual(totalLength, 279)
XCTAssertEqual(headersLength, 263)
},
onHeaderReceived: { header in
decodedHeaders.append(header)
Expand Down
7 changes: 7 additions & 0 deletions Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ class HTTPTests: HTTPClientTestFixture {
_ = try await sendHTTPRequest(method: "GET", endpoint: host, path: getPath, connectionManager: connectionManager)
_ = try await sendHTTPRequest(method: "GET", endpoint: host, path: "/delete", expectedStatus: 404, connectionManager: connectionManager)
}

func testGetHTTPSRequestWithUtf8Header() async throws {
let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443)
let utf8Header = HTTPHeader(name: "TestHeader", value: "TestValueWithEmoji🤯")
let headers = try await sendHTTPRequest(method: "GET", endpoint: host, path: "/response-headers?\(utf8Header.name)=\(utf8Header.value)", connectionManager: connectionManager).headers
XCTAssertTrue(headers.contains(where: {$0.name == utf8Header.name && $0.value==utf8Header.value}))
}

func testGetHTTPRequest() async throws {
let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: false, port: 80)
Expand Down

0 comments on commit 4aebaa6

Please sign in to comment.