Skip to content

Commit

Permalink
Add tests for required feature
Browse files Browse the repository at this point in the history
  • Loading branch information
gregor-i committed Nov 18, 2024
1 parent daebaa3 commit a811bce
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ private[compiler] class ParseFromGenerator(

private def usesBaseTypeInBuilder(field: FieldDescriptor) = field.isSingular

val requiredFieldMap: Map[FieldDescriptor, Int] =
message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex.toMap
private val requiredFields: Seq[(FieldDescriptor, Int)] =
message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex

private val requiredFieldMap: Map[FieldDescriptor, Int] =
requiredFields.toMap

val myFullScalaName = message.scalaType.fullNameWithMaybeRoot(message)

Expand Down Expand Up @@ -231,16 +234,15 @@ private[compiler] class ParseFromGenerator(
p.add(s"""if (${r}) {""")
.indent
.add("val __missingFields = Seq.newBuilder[_root_.scala.Predef.String]")
.print(requiredFieldMap.toSeq.sortBy(_._2)) {
case (p, (fieldDescriptor, fieldNumber)) =>
val bitmask = s"0x${"%x".format(1L << fieldNumber)}L"
val fieldVariable = s"__requiredFields${fieldNumber / 64}"
p.add(
s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}""""
)
.print(requiredFields) { case (p, (fieldDescriptor, fieldNumber)) =>
val bitmask = f"${1L << fieldNumber}%#018xL"
val fieldVariable = s"__requiredFields${fieldNumber / 64}"
p.add(
s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}""""
)
}
.add(
s"""val __message = s"Message missing required fields: $${__missingFields.result.mkString(", ")}"""",
s"""val __message = s"Message missing required fields: $${__missingFields.result().mkString(", ")}"""",
s"""throw new _root_.com.google.protobuf.InvalidProtocolBufferException(__message)"""
)
.outdent
Expand Down
51 changes: 50 additions & 1 deletion e2e/src/test/scala/RequiredFieldsSpec.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,60 @@
import com.google.protobuf.InvalidProtocolBufferException
import com.thesamet.proto.e2e.reqs.RequiredFields
import protobuf_unittest.unittest.TestEmptyMessage
import scalapb.UnknownFieldSet

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class RequiredFieldsSpec extends AnyFlatSpec with Matchers {

private val descriptor = RequiredFields.javaDescriptor

private def partialMessage(fields: Map[Int, Int]): Array[Byte] = {
val fieldSet = fields.foldLeft(UnknownFieldSet.empty){ case (fieldSet, (field, value)) =>
fieldSet
.withField(field, UnknownFieldSet.Field(varint = Seq(value)))
}

TestEmptyMessage(fieldSet).toByteArray
}

private val allFieldsSet: Map[Int, Int] = (100 to 164).map(i => (i, i)).toMap

"RequiredMessage" should "throw InvalidProtocolBufferException for empty byte array" in {
intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]()))
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]()))

exception.getMessage() must startWith("Message missing required fields")
}

it should "throw no exception when all fields are set correctly" in {
val parsed = RequiredFields.parseFrom(partialMessage(allFieldsSet))
parsed must be(a[RequiredFields])
parsed.f0 must be(100)
parsed.f64 must be(164)
}

it should "throw an exception if a field is missing and name the missing field" in {
val fields = allFieldsSet.removed(123)
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields)))

exception.getMessage() must be("Message missing required fields: f23")
}

it should "throw an exception if a multiple fields are missing and name those missing fields" in {
val fields = allFieldsSet.removed(123).removed(164).removed(130)
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields)))

exception.getMessage() must be("Message missing required fields: f23, f30, f64")
}

it should "sort the missing fields by field number" in {
val fields = Map.empty[Int, Int]
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields)))
val missingFields =exception.getMessage().stripPrefix("Message missing required fields: ").split(", ")

missingFields.sortBy[Int](field => descriptor.findFieldByName(field).getNumber()) must be(missingFields)

missingFields.toSeq mustBe Seq.tabulate(65)(i => s"f$i")
}
}

0 comments on commit a811bce

Please sign in to comment.