From 41fdf6a6e01b7ed28fee0f93647e7d6bdd502688 Mon Sep 17 00:00:00 2001 From: John Szumski Date: Tue, 14 Mar 2023 09:53:15 -0400 Subject: [PATCH] Add the ability to encode and decode a size delimited message collection in Swift. --- .../swift/ProtoCodable/ProtoDecoder.swift | 53 ++++++++++++++++++- .../swift/ProtoCodable/ProtoEncoder.swift | 34 +++++++++++- .../main/swift/ProtoCodable/WriteBuffer.swift | 8 +++ .../src/test/swift/ProtoDecoderTests.swift | 7 +++ .../src/test/swift/ProtoEncoderTests.swift | 8 +++ .../src/test/swift/RoundTripTests.swift | 14 +++++ .../src/test/swift/WriteBufferTests.swift | 7 +++ 7 files changed, 128 insertions(+), 3 deletions(-) diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift index 25d26771eb..e59fcdc3de 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift @@ -125,6 +125,12 @@ public final class ProtoDecoder { // MARK: - Public Methods + /// Decodes the provided data into an instance of the requested type. + /// + /// - Parameters: + /// - type: the type to decode + /// - data: the serialized data for the message + /// - Returns: the decoded message public func decode(_ type: T.Type, from data: Data) throws -> T { var value: T? try data.withUnsafeBytes { buffer in @@ -148,5 +154,50 @@ public final class ProtoDecoder { return unwrappedValue } -} + /// Decodes the provided size-delimited data into instances of the requested type. + /// + /// A size-delimited collection of messages is a sequence of varint + message pairs + /// where the varint indicates the size of the subsequent message. + /// + /// - Parameters: + /// - type: the type to decode + /// - data: the serialized size-delimited data for the messages + /// - Returns: an array of the decoded messages + public func decodeSizeDelimited(_ type: T.Type, from data: Data) throws -> [T] { + var values: [T] = [] + + try data.withUnsafeBytes { buffer in + // Handle the empty-data case. + guard let baseAddress = buffer.baseAddress, buffer.count > 0 else { + return + } + + let fullBuffer = ReadBuffer( + storage: baseAddress.bindMemory(to: UInt8.self, capacity: buffer.count), + count: buffer.count + ) + while fullBuffer.isDataRemaining, let size = try? fullBuffer.readVarint64() { + if size == 0 { break } + + let messageBuffer = ReadBuffer( + storage: fullBuffer.pointer, + count: Int(size) + ) + + let reader = ProtoReader( + buffer: messageBuffer, + enumDecodingStrategy: enumDecodingStrategy + ) + + values.append(try reader.decode(type)) + + // Advance the buffer before reading the next item in the stream + _ = try fullBuffer.readBuffer(count: Int(size)) + } + } + + return values + } + +} diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoEncoder.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoEncoder.swift index b2fda57a6d..55fd0d143a 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoEncoder.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoEncoder.swift @@ -82,13 +82,43 @@ public final class ProtoEncoder { let writer = ProtoWriter( data: .init(capacity: structSize), - outputFormatting: [], + outputFormatting: outputFormatting, rootMessageProtoSyntax: T.self.protoSyntax ?? .proto2 ) - writer.outputFormatting = outputFormatting + try value.encode(to: writer) return Data(writer.buffer, copyBytes: false) } + public func encodeSizeDelimited(_ values: [T]) throws -> Data { + // Use the size of the struct as an initial estimate for the space needed. + let structSize = MemoryLayout.size(ofValue: T.self) + + // Reserve space for the largest varint size + let varintSize = 8 + + let fullBuffer = WriteBuffer(capacity: (structSize + varintSize) * values.count) + + for value in values { + let writer = ProtoWriter( + data: .init(), + outputFormatting: outputFormatting, + rootMessageProtoSyntax: T.self.protoSyntax ?? .proto2 + ) + + try value.encode(to: writer) + + if writer.buffer.count == 0 { + continue + } + + // write this value's size + contents to the main buffer + fullBuffer.writeVarint(UInt64(writer.buffer.count), at: fullBuffer.count) + fullBuffer.append(writer.buffer) + } + + return Data(fullBuffer, copyBytes: false) + } + } diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/WriteBuffer.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/WriteBuffer.swift index c2b277e017..7378555bf7 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/WriteBuffer.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/WriteBuffer.swift @@ -50,6 +50,8 @@ final class WriteBuffer { // MARK: - Public Methods func append(_ data: Data) { + guard !data.isEmpty else { return } + expandIfNeeded(adding: data.count) data.copyBytes(to: storage.advanced(by: count), count: data.count) @@ -64,6 +66,8 @@ final class WriteBuffer { } func append(_ value: [UInt8]) { + guard !value.isEmpty else { return } + expandIfNeeded(adding: value.count) for byte in value { @@ -74,6 +78,8 @@ final class WriteBuffer { func append(_ value: WriteBuffer) { precondition(value !== self) + guard value.count > 0 else { return } + expandIfNeeded(adding: value.count) memcpy(storage.advanced(by: count), value.storage, value.count) @@ -81,6 +87,8 @@ final class WriteBuffer { } func append(_ value: UnsafeRawBufferPointer) { + guard value.count > 0 else { return } + expandIfNeeded(adding: value.count) memcpy(storage.advanced(by: count), value.baseAddress, value.count) diff --git a/wire-runtime-swift/src/test/swift/ProtoDecoderTests.swift b/wire-runtime-swift/src/test/swift/ProtoDecoderTests.swift index 299e629eaf..a0cb376f2e 100644 --- a/wire-runtime-swift/src/test/swift/ProtoDecoderTests.swift +++ b/wire-runtime-swift/src/test/swift/ProtoDecoderTests.swift @@ -26,6 +26,13 @@ final class ProtoDecoderTests: XCTestCase { XCTAssertEqual(object, SimpleOptional2()) } + func testDecodeEmptySizeDelimitedData() throws { + let decoder = ProtoDecoder() + let object = try decoder.decodeSizeDelimited(SimpleOptional2.self, from: Data()) + + XCTAssertEqual(object, []) + } + func testDecodeEmptyDataTwice() throws { let decoder = ProtoDecoder() // The empty message case is optimized to reuse objects, so make sure diff --git a/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift b/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift index 456d885dbe..c02ae05734 100644 --- a/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift +++ b/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift @@ -44,4 +44,12 @@ final class ProtoEncoderTests: XCTestCase { XCTAssertEqual(jsonString, "{}") } + + func testEncodeEmptySizeDelimitedMessage() throws { + let object = EmptyMessage() + let encoder = ProtoEncoder() + let data = try encoder.encodeSizeDelimited([object]) + + XCTAssertEqual(data, Data()) + } } diff --git a/wire-runtime-swift/src/test/swift/RoundTripTests.swift b/wire-runtime-swift/src/test/swift/RoundTripTests.swift index bcd835b276..079a8a6642 100644 --- a/wire-runtime-swift/src/test/swift/RoundTripTests.swift +++ b/wire-runtime-swift/src/test/swift/RoundTripTests.swift @@ -60,4 +60,18 @@ final class RoundTripTests: XCTestCase { XCTAssertEqual(decodedEmpty, empty) } + func testSizeDelimited() throws { + let values = [ + Person3(name: "John Doe", id: 123), + Person3(name: "Jane Doe", id: 456, email: "jdoe@example.com") + ] + + let encoder = ProtoEncoder() + let data = try encoder.encodeSizeDelimited(values) + + let decoder = ProtoDecoder() + let decodedValues = try decoder.decodeSizeDelimited(Person3.self, from: data) + + XCTAssertEqual(decodedValues, values) + } } diff --git a/wire-runtime-swift/src/test/swift/WriteBufferTests.swift b/wire-runtime-swift/src/test/swift/WriteBufferTests.swift index 096aa9419e..08c4abaf03 100644 --- a/wire-runtime-swift/src/test/swift/WriteBufferTests.swift +++ b/wire-runtime-swift/src/test/swift/WriteBufferTests.swift @@ -60,4 +60,11 @@ final class WriteBufferTests: XCTestCase { XCTAssertEqual(Data(buffer, copyBytes: true), Data(hexEncoded: "0011")) } + func testAppendEmptyFirst() { + let buffer = WriteBuffer() + buffer.append(Data()) + + XCTAssertEqual(Data(buffer, copyBytes: true), Data()) + } + }