From 9d764b634b9f12540671ca3777ac0fe9afe533c7 Mon Sep 17 00:00:00 2001 From: John Szumski Date: Tue, 14 Mar 2023 09:51:38 -0400 Subject: [PATCH] Improve encoding and decoding of proto3 messages with default values. --- .../swift/ProtoCodable/ProtoCodable.swift | 15 ++++ .../swift/ProtoCodable/ProtoDecoder.swift | 3 + .../main/swift/ProtoCodable/ProtoWriter.swift | 15 ++++ .../main/swift/wellknowntypes/Duration.swift | 8 +- .../main/swift/wellknowntypes/Timestamp.swift | 8 +- wire-runtime-swift/src/test/proto/empty.proto | 16 ++++ .../src/test/swift/ProtoEncoderTests.swift | 2 +- .../src/test/swift/RoundTripTests.swift | 22 +++++ .../com/squareup/wire/swift/SwiftGenerator.kt | 84 ++++++++++++++++--- 9 files changed, 151 insertions(+), 22 deletions(-) diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoCodable.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoCodable.swift index 61aec3f68c..21c5dffbeb 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoCodable.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoCodable.swift @@ -105,3 +105,18 @@ extension ProtoDecodable { } } + +extension ProtoEnum where Self : RawRepresentable, RawValue == UInt32 { + + /** + A convenience function used with enum fields that throws an error if the field is null + and its default value can't be used instead. + */ + public static func defaultIfMissing(_ value: Self?) throws -> Self { + guard let value = value ?? Self(rawValue: 0) else { + throw ProtoDecoder.Error.missingEnumDefaultValue(type: Self.self) + } + return value + } + +} diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift index cf51ef1014..25d26771eb 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoDecoder.swift @@ -51,6 +51,7 @@ public final class ProtoDecoder { case mapEntryWithoutKey(value: Any?) case mapEntryWithoutValue(key: Any) case messageWithoutLength + case missingEnumDefaultValue(type: Any.Type) case missingRequiredField(typeName: String, fieldName: String) case recursionLimitExceeded case unexpectedEndOfData @@ -80,6 +81,8 @@ public final class ProtoDecoder { return "Map entry with \(key) did not include a value." case .messageWithoutLength: return "Attempting to decode a message without first decoding the length of that message." + case let .missingEnumDefaultValue(type): + return "Could not assign a default value of 0 for enum type \(String(describing: type))" case let .missingRequiredField(typeName, fieldName): return "Required field \(fieldName) for type \(typeName) is not included in the message data." case let .boxedValueMissingField(type): diff --git a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoWriter.swift b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoWriter.swift index c91a4b632e..ce11473b5e 100644 --- a/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoWriter.swift +++ b/wire-runtime-swift/src/main/swift/ProtoCodable/ProtoWriter.swift @@ -145,6 +145,21 @@ public final class ProtoWriter { try value.encode(to: self) } + /** Encode a required `bool` field */ + public func encode(tag: UInt32, value: Bool) throws { + if value == false && isProto3 { return } + try encode(tag: tag, value: value as Bool?) + } + + /** Encode an optional `bool` field */ + public func encode(tag: UInt32, value: Bool?) throws { + guard let value = value else { return } + + let key = ProtoWriter.makeFieldKey(tag: tag, wireType: .varint) + writeVarint(key) + try value.encode(to: self) + } + /** Encode a required `int32`, `sfixed32`, or `sint32` field */ public func encode(tag: UInt32, value: Int32, encoding: ProtoIntEncoding = .variable) throws { // Don't encode default values if using proto3 syntax. diff --git a/wire-runtime-swift/src/main/swift/wellknowntypes/Duration.swift b/wire-runtime-swift/src/main/swift/wellknowntypes/Duration.swift index 9db67fc647..4d3912abce 100644 --- a/wire-runtime-swift/src/main/swift/wellknowntypes/Duration.swift +++ b/wire-runtime-swift/src/main/swift/wellknowntypes/Duration.swift @@ -92,8 +92,8 @@ extension Duration : ProtoMessage { extension Duration : Proto3Codable { public init(from reader: ProtoReader) throws { - var seconds: Int64? = nil - var nanos: Int32? = nil + var seconds: Int64 = 0 + var nanos: Int32 = 0 let token = try reader.beginMessage() while let tag = try reader.nextTag(token: token) { @@ -105,8 +105,8 @@ extension Duration : Proto3Codable { } self.unknownFields = try reader.endMessage(token: token) - self.seconds = try Duration.checkIfMissing(seconds, "seconds") - self.nanos = try Duration.checkIfMissing(nanos, "nanos") + self.seconds = seconds + self.nanos = nanos } public func encode(to writer: ProtoWriter) throws { diff --git a/wire-runtime-swift/src/main/swift/wellknowntypes/Timestamp.swift b/wire-runtime-swift/src/main/swift/wellknowntypes/Timestamp.swift index b455c343f8..09da7b1a5b 100644 --- a/wire-runtime-swift/src/main/swift/wellknowntypes/Timestamp.swift +++ b/wire-runtime-swift/src/main/swift/wellknowntypes/Timestamp.swift @@ -104,8 +104,8 @@ extension Timestamp : ProtoMessage { extension Timestamp : Proto3Codable { public init(from reader: ProtoReader) throws { - var seconds: Int64? = nil - var nanos: Int32? = nil + var seconds: Int64 = 0 + var nanos: Int32 = 0 let token = try reader.beginMessage() while let tag = try reader.nextTag(token: token) { @@ -117,8 +117,8 @@ extension Timestamp : Proto3Codable { } self.unknownFields = try reader.endMessage(token: token) - self.seconds = try Timestamp.checkIfMissing(seconds, "seconds") - self.nanos = try Timestamp.checkIfMissing(nanos, "nanos") + self.seconds = seconds + self.nanos = nanos } public func encode(to writer: ProtoWriter) throws { diff --git a/wire-runtime-swift/src/test/proto/empty.proto b/wire-runtime-swift/src/test/proto/empty.proto index a1837cd6ea..de96ebcfa7 100644 --- a/wire-runtime-swift/src/test/proto/empty.proto +++ b/wire-runtime-swift/src/test/proto/empty.proto @@ -20,5 +20,21 @@ message EmptyMessage { } message EmptyOmitted { + enum EmptyEnum { + UNKNOWN = 0; + OTHER = 1; + } + + message EmptyNested { + int32 nested = 1; + } + int32 numeric_value = 1; + string string_value = 2; + bytes bytes_value = 3; + bool bool_value = 4; + EmptyEnum enum_value = 5; + EmptyNested message_value = 6; + repeated string repeated_value = 7; + map map_value = 8; } diff --git a/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift b/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift index 1ac90898d4..456d885dbe 100644 --- a/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift +++ b/wire-runtime-swift/src/test/swift/ProtoEncoderTests.swift @@ -29,7 +29,7 @@ final class ProtoEncoderTests: XCTestCase { } func testEncodeEmptyProtoMessageWithIdentityValues() throws { - let object = EmptyOmitted(numeric_value: 0) + let object = EmptyOmitted(numeric_value: 0, string_value: "", bytes_value: .init(), bool_value: false, enum_value: .UNKNOWN) let encoder = ProtoEncoder() let data = try encoder.encode(object) diff --git a/wire-runtime-swift/src/test/swift/RoundTripTests.swift b/wire-runtime-swift/src/test/swift/RoundTripTests.swift index de581559b6..bcd835b276 100644 --- a/wire-runtime-swift/src/test/swift/RoundTripTests.swift +++ b/wire-runtime-swift/src/test/swift/RoundTripTests.swift @@ -38,4 +38,26 @@ final class RoundTripTests: XCTestCase { XCTAssertEqual(decodedPerson, person) } + // ensure that fields set to their identity value survive a roundtrip when omitted over the wire + func testProto3IdentityValues() throws { + let empty = EmptyOmitted( + numeric_value: 0, + string_value: "", + bytes_value: Data(), + bool_value: false, + enum_value: .UNKNOWN, + message_value: nil, + repeated_value: [], + map_value: [:] + ) + + let encoder = ProtoEncoder() + let data = try encoder.encode(empty) + + let decoder = ProtoDecoder() + let decodedEmpty = try decoder.decode(EmptyOmitted.self, from: data) + + XCTAssertEqual(decodedEmpty, empty) + } + } diff --git a/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt b/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt index 7327ac53f4..857920c5cb 100644 --- a/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt +++ b/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt @@ -132,6 +132,22 @@ class SwiftGenerator private constructor( else -> null } + // see https://protobuf.dev/programming-guides/proto3/#default + private val Field.proto3InitialValue: String + get() = when { + isMap -> "[:]" + isRepeated -> "[]" + isOptional -> "nil" + else -> when (typeName.makeNonOptional()) { + BOOL -> "false" + DOUBLE, FLOAT -> "0" + INT32, UINT32, INT64, UINT64 -> "0" + STRING -> """""""" // evaluates to the empty string + DATA -> ".init()" + else -> "nil" + } + } + private val Field.codableName: String? get() = jsonName?.takeIf { it != name } ?: camelCase(name).takeIf { it != name } @@ -192,6 +208,27 @@ class SwiftGenerator private constructor( private val MessageType.isHeapAllocated get() = fields.size + oneOfs.size >= 16 + /** + * Checks that every enum in a proto3 message contains a value with tag 0. + * + * @throws NoSuchElementException if the case doesn't exist + */ + @Throws(NoSuchElementException::class) + private fun validateProto3DefaultsExist(type: MessageType) { + if (type.syntax == PROTO_2) { return } + + // validate each enum field + type + .fields + .mapNotNull { schema.getType(it.type!!) as? EnumType } + .forEach { enum -> + // ensure that a 0 case exists + if (enum.constants.filter { it.tag == 0 }.isEmpty()) { + throw NoSuchElementException("Missing a zero value for ${enum.name}") + } + } + } + @OptIn(ExperimentalStdlibApi::class) // TODO move to build flag private fun generateMessage( type: MessageType, @@ -208,6 +245,8 @@ class SwiftGenerator private constructor( val typeSpecs = mutableListOf() + validateProto3DefaultsExist(type) + typeSpecs += TypeSpec.structBuilder(structType) .addModifiers(PUBLIC) .apply { @@ -449,18 +488,30 @@ class SwiftGenerator private constructor( .addParameter("from", reader, protoReader) .throws(true) .apply { - // Declare locals into which everything is writen before promoting to members. + // Declare locals into which everything is written before promoting to members. type.fields.forEach { field -> - val localType = if (field.isRepeated || field.isMap) { - field.typeName - } else { - field.typeName.makeOptional() + val localType = when (type.syntax) { + PROTO_2 -> if (field.isRepeated || field.isMap) { + field.typeName + } else { + field.typeName.makeOptional() + } + PROTO_3 -> if (field.isOptional || (field.isEnum && !field.isRepeated)) { + field.typeName.makeOptional() + } else { + field.typeName + } } - val initializer = when { - field.isMap -> "[:]" - field.isRepeated -> "[]" - else -> "nil" + + val initializer = when (type.syntax) { + PROTO_2 -> when { + field.isMap -> "[:]" + field.isRepeated -> "[]" + else -> "nil" + } + PROTO_3 -> field.proto3InitialValue } + addStatement("var %N: %T = %L", field.name, localType, initializer) } type.oneOfs.forEach { oneOf -> @@ -533,10 +584,17 @@ class SwiftGenerator private constructor( // Check required and bind members. addStatement("") type.fields.forEach { field -> - val initializer = if (field.isOptional || field.isRepeated || field.isMap) { - CodeBlock.of("%N", field.name) - } else { - CodeBlock.of("try %1T.checkIfMissing(%2N, %2S)", structType, field.name) + val initializer = when(type.syntax) { + PROTO_2 -> if (field.isOptional || field.isRepeated || field.isMap) { + CodeBlock.of("%N", field.name) + } else { + CodeBlock.of("try %1T.checkIfMissing(%2N, %2S)", structType, field.name) + } + PROTO_3 -> if (field.isEnum && !field.isRepeated) { + CodeBlock.of("try %1T.defaultIfMissing(%2N)", field.typeName.makeNonOptional(), field.name) + } else { + CodeBlock.of("%N", field.name) + } } addStatement("self.%N = %L", field.name, initializer) }