From e273a7e586510460c4d978e3acef7b09c3d71c55 Mon Sep 17 00:00:00 2001 From: Luca <5348596+lucaviolanti@users.noreply.github.com> Date: Mon, 28 Nov 2022 14:40:30 +0000 Subject: [PATCH] Allow tuning loader options (#333) --- .../src/main/scala/io/circe/yaml/Parser.scala | 54 ++++++++++++++---- .../scala/io/circe/yaml/parser/package.scala | 10 ++-- .../scala/io/circe/yaml/ParserTests.scala | 57 ++++++++++++++++++- 3 files changed, 104 insertions(+), 17 deletions(-) diff --git a/circe-yaml/src/main/scala/io/circe/yaml/Parser.scala b/circe-yaml/src/main/scala/io/circe/yaml/Parser.scala index 14d5de0a..7120a5be 100644 --- a/circe-yaml/src/main/scala/io/circe/yaml/Parser.scala +++ b/circe-yaml/src/main/scala/io/circe/yaml/Parser.scala @@ -19,6 +19,7 @@ package io.circe.yaml import cats.data.ValidatedNel import cats.syntax.either._ import io.circe._ +import io.circe.yaml.Parser.default.loaderOptions import org.yaml.snakeyaml.LoaderOptions import org.yaml.snakeyaml.Yaml import org.yaml.snakeyaml.constructor.SafeConstructor @@ -31,12 +32,16 @@ import scala.collection.JavaConverters._ import Parser._ final case class Parser( - maxAliasesForCollections: Int = 50 + maxAliasesForCollections: Int = Parser.defaultMaxAliasesForCollections, + nestingDepthLimit: Int = Parser.defaultNestingDepthLimit, + codePointLimit: Int = Parser.defaultCodePointLimit ) extends yaml.common.Parser { private val loaderOptions = { val options = new LoaderOptions() options.setMaxAliasesForCollections(maxAliasesForCollections) + options.setNestingDepthLimit(nestingDepthLimit) + options.setCodePointLimit(codePointLimit) options } @@ -47,15 +52,16 @@ final case class Parser( */ def parse(yaml: Reader): Either[ParsingFailure, Json] = for { parsed <- parseSingle(yaml) - json <- yamlToJson(parsed) + json <- yamlToJson(parsed, loaderOptions) } yield json def parse(yaml: String): Either[ParsingFailure, Json] = parse(new StringReader(yaml)) - def parseDocuments(yaml: Reader): Stream[Either[ParsingFailure, Json]] = parseStream(yaml) match { - case Left(error) => Stream(Left(error)) - case Right(stream) => stream.map(yamlToJson) - } + def parseDocuments(yaml: Reader): Stream[Either[ParsingFailure, Json]] = + parseStream(yaml) match { + case Left(error) => Stream(Left(error)) + case Right(stream) => stream.map(n => yamlToJson(n, loaderOptions)) + } def parseDocuments(yaml: String): Stream[Either[ParsingFailure, Json]] = parseDocuments(new StringReader(yaml)) @@ -67,6 +73,21 @@ final case class Parser( .catchNonFatal(new Yaml(loaderOptions).composeAll(reader).asScala.toStream) .leftMap(err => ParsingFailure(err.getMessage, err)) + def copy( + maxAliasesForCollections: Int = this.maxAliasesForCollections, + nestingDepthLimit: Int = this.nestingDepthLimit, + codePointLimit: Int = this.codePointLimit + ): Parser = new Parser(maxAliasesForCollections, nestingDepthLimit, codePointLimit) + + def copy(maxAliasesForCollections: Int): Parser = new Parser( + maxAliasesForCollections = maxAliasesForCollections, + nestingDepthLimit = this.nestingDepthLimit, + codePointLimit = this.codePointLimit + ) + + def this(maxAliasesForCollections: Int) = + this(maxAliasesForCollections, Parser.defaultNestingDepthLimit, Parser.defaultCodePointLimit) + final def decode[A: Decoder](input: Reader): Either[Error, A] = finishDecode(parse(input)) @@ -75,8 +96,16 @@ final case class Parser( } object Parser { + val defaultMaxAliasesForCollections: Int = 50 // to prevent YAML at + // https://en.wikipedia.org/wiki/Billion_laughs_attack + val defaultNestingDepthLimit: Int = 50 + val defaultCodePointLimit: Int = 3 * 1024 * 1024 // 3MB + val default: Parser = Parser() + def apply(maxAliasesForCollections: Int): Parser = + new Parser(maxAliasesForCollections = maxAliasesForCollections) + private[yaml] object CustomTag { def unapply(tag: Tag): Option[String] = if (!tag.startsWith(Tag.PREFIX)) Some(tag.getValue) @@ -84,7 +113,7 @@ object Parser { None } - private[yaml] class FlatteningConstructor extends SafeConstructor { + private[yaml] class FlatteningConstructor(val loaderOptions: LoaderOptions) extends SafeConstructor(loaderOptions) { def flatten(node: MappingNode): MappingNode = { flattenMapping(node) node @@ -94,9 +123,12 @@ object Parser { getConstructor(node).construct(node) } - private[yaml] def yamlToJson(node: Node): Either[ParsingFailure, Json] = { + private[yaml] def yamlToJson(node: Node): Either[ParsingFailure, Json] = + yamlToJson(node, loaderOptions) + + private[yaml] def yamlToJson(node: Node, loaderOptions: LoaderOptions): Either[ParsingFailure, Json] = { // Isn't thread-safe internally, may hence not be shared - val flattener: FlatteningConstructor = new FlatteningConstructor + val flattener: FlatteningConstructor = new FlatteningConstructor(loaderOptions) def convertScalarNode(node: ScalarNode) = Either .catchNonFatal(node.getTag match { @@ -146,7 +178,7 @@ object Parser { for { obj <- objEither key <- convertKeyNode(tup.getKeyNode) - value <- yamlToJson(tup.getValueNode) + value <- yamlToJson(tup.getValueNode, loaderOptions) } yield obj.add(key, value) } .map(Json.fromJsonObject) @@ -155,7 +187,7 @@ object Parser { .foldLeft(Either.right[ParsingFailure, List[Json]](List.empty[Json])) { (arrEither, node) => for { arr <- arrEither - value <- yamlToJson(node) + value <- yamlToJson(node, loaderOptions) } yield value :: arr } .map(arr => Json.fromValues(arr.reverse)) diff --git a/circe-yaml/src/main/scala/io/circe/yaml/parser/package.scala b/circe-yaml/src/main/scala/io/circe/yaml/parser/package.scala index 98a76471..a7203257 100644 --- a/circe-yaml/src/main/scala/io/circe/yaml/parser/package.scala +++ b/circe-yaml/src/main/scala/io/circe/yaml/parser/package.scala @@ -46,9 +46,11 @@ package object parser extends io.circe.yaml.common.Parser { Parser.default.decodeAccumulating[A](input) @deprecated("moved to Parser.CustomTag", since = "0.14.2") - private val loaderOptions = { + private val loaderOptions: LoaderOptions = { val options = new LoaderOptions() - options.setMaxAliasesForCollections(50) + options.setMaxAliasesForCollections(Parser.defaultMaxAliasesForCollections) + options.setNestingDepthLimit(Parser.defaultNestingDepthLimit) + options.setCodePointLimit(Parser.defaultCodePointLimit) options } @@ -69,8 +71,8 @@ package object parser extends io.circe.yaml.common.Parser { } @deprecated("moved to Parser.CustomTag", since = "0.14.2") - private[this] class FlatteningConstructor extends Parser.FlatteningConstructor + private[this] class FlatteningConstructor extends Parser.FlatteningConstructor(loaderOptions) @deprecated("moved to Parser.CustomTag", since = "0.14.2") - private[this] def yamlToJson(node: Node): Either[ParsingFailure, Json] = Parser.yamlToJson(node) + private[this] def yamlToJson(node: Node): Either[ParsingFailure, Json] = Parser.yamlToJson(node, loaderOptions) } diff --git a/circe-yaml/src/test/scala/io/circe/yaml/ParserTests.scala b/circe-yaml/src/test/scala/io/circe/yaml/ParserTests.scala index 4b4bb18f..3cd1035d 100644 --- a/circe-yaml/src/test/scala/io/circe/yaml/ParserTests.scala +++ b/circe-yaml/src/test/scala/io/circe/yaml/ParserTests.scala @@ -89,7 +89,6 @@ class ParserTests extends AnyFlatSpec with Matchers with EitherValues { .parse( "" ) - .right .value == Json.False ) } @@ -100,7 +99,6 @@ class ParserTests extends AnyFlatSpec with Matchers with EitherValues { .parse( " " ) - .right .value == Json.False ) } @@ -233,4 +231,59 @@ class ParserTests extends AnyFlatSpec with Matchers with EitherValues { assertResult(1)(result.size) assert(result.head.isLeft) } + + it should "parse when within depth limits" in { + assert( + Parser(nestingDepthLimit = 3) + .parse( + """ + | foo: + | bar: + | baz + |""".stripMargin + ) + .isRight + ) + } + + it should "fail to parse when depth limit is exceeded" in { + assert( + Parser(nestingDepthLimit = 1) + .parse( + """ + | foo: + | bar: + | baz + |""".stripMargin + ) + .isLeft + ) + } + + it should "parse when within code point limit" in { + assert( + Parser(codePointLimit = 1 * 1024 * 1024) // 1MB + .parse( + """ + | foo: + | bar: + | baz + |""".stripMargin + ) + .isRight + ) + } + + it should "fail to parse when code point limit is exceeded" in { + assert( + Parser(codePointLimit = 13) // 13B + .parse( + """ + | foo: + | bar + |""".stripMargin + ) + .isLeft + ) + } }