Skip to content

Commit

Permalink
Allow tuning loader options (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucaviolanti authored Nov 28, 2022
1 parent e6cdde3 commit e273a7e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 17 deletions.
54 changes: 43 additions & 11 deletions circe-yaml/src/main/scala/io/circe/yaml/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package io.circe.yaml
import cats.data.ValidatedNel
import cats.syntax.either._
import io.circe._
import io.circe.yaml.Parser.default.loaderOptions
import org.yaml.snakeyaml.LoaderOptions
import org.yaml.snakeyaml.Yaml
import org.yaml.snakeyaml.constructor.SafeConstructor
Expand All @@ -31,12 +32,16 @@ import scala.collection.JavaConverters._
import Parser._

final case class Parser(
maxAliasesForCollections: Int = 50
maxAliasesForCollections: Int = Parser.defaultMaxAliasesForCollections,
nestingDepthLimit: Int = Parser.defaultNestingDepthLimit,
codePointLimit: Int = Parser.defaultCodePointLimit
) extends yaml.common.Parser {

private val loaderOptions = {
val options = new LoaderOptions()
options.setMaxAliasesForCollections(maxAliasesForCollections)
options.setNestingDepthLimit(nestingDepthLimit)
options.setCodePointLimit(codePointLimit)
options
}

Expand All @@ -47,15 +52,16 @@ final case class Parser(
*/
def parse(yaml: Reader): Either[ParsingFailure, Json] = for {
parsed <- parseSingle(yaml)
json <- yamlToJson(parsed)
json <- yamlToJson(parsed, loaderOptions)
} yield json

def parse(yaml: String): Either[ParsingFailure, Json] = parse(new StringReader(yaml))

def parseDocuments(yaml: Reader): Stream[Either[ParsingFailure, Json]] = parseStream(yaml) match {
case Left(error) => Stream(Left(error))
case Right(stream) => stream.map(yamlToJson)
}
def parseDocuments(yaml: Reader): Stream[Either[ParsingFailure, Json]] =
parseStream(yaml) match {
case Left(error) => Stream(Left(error))
case Right(stream) => stream.map(n => yamlToJson(n, loaderOptions))
}

def parseDocuments(yaml: String): Stream[Either[ParsingFailure, Json]] = parseDocuments(new StringReader(yaml))

Expand All @@ -67,6 +73,21 @@ final case class Parser(
.catchNonFatal(new Yaml(loaderOptions).composeAll(reader).asScala.toStream)
.leftMap(err => ParsingFailure(err.getMessage, err))

def copy(
maxAliasesForCollections: Int = this.maxAliasesForCollections,
nestingDepthLimit: Int = this.nestingDepthLimit,
codePointLimit: Int = this.codePointLimit
): Parser = new Parser(maxAliasesForCollections, nestingDepthLimit, codePointLimit)

def copy(maxAliasesForCollections: Int): Parser = new Parser(
maxAliasesForCollections = maxAliasesForCollections,
nestingDepthLimit = this.nestingDepthLimit,
codePointLimit = this.codePointLimit
)

def this(maxAliasesForCollections: Int) =
this(maxAliasesForCollections, Parser.defaultNestingDepthLimit, Parser.defaultCodePointLimit)

final def decode[A: Decoder](input: Reader): Either[Error, A] =
finishDecode(parse(input))

Expand All @@ -75,16 +96,24 @@ final case class Parser(
}

object Parser {
val defaultMaxAliasesForCollections: Int = 50 // to prevent YAML at
// https://en.wikipedia.org/wiki/Billion_laughs_attack
val defaultNestingDepthLimit: Int = 50
val defaultCodePointLimit: Int = 3 * 1024 * 1024 // 3MB

val default: Parser = Parser()

def apply(maxAliasesForCollections: Int): Parser =
new Parser(maxAliasesForCollections = maxAliasesForCollections)

private[yaml] object CustomTag {
def unapply(tag: Tag): Option[String] = if (!tag.startsWith(Tag.PREFIX))
Some(tag.getValue)
else
None
}

private[yaml] class FlatteningConstructor extends SafeConstructor {
private[yaml] class FlatteningConstructor(val loaderOptions: LoaderOptions) extends SafeConstructor(loaderOptions) {
def flatten(node: MappingNode): MappingNode = {
flattenMapping(node)
node
Expand All @@ -94,9 +123,12 @@ object Parser {
getConstructor(node).construct(node)
}

private[yaml] def yamlToJson(node: Node): Either[ParsingFailure, Json] = {
private[yaml] def yamlToJson(node: Node): Either[ParsingFailure, Json] =
yamlToJson(node, loaderOptions)

private[yaml] def yamlToJson(node: Node, loaderOptions: LoaderOptions): Either[ParsingFailure, Json] = {
// Isn't thread-safe internally, may hence not be shared
val flattener: FlatteningConstructor = new FlatteningConstructor
val flattener: FlatteningConstructor = new FlatteningConstructor(loaderOptions)

def convertScalarNode(node: ScalarNode) = Either
.catchNonFatal(node.getTag match {
Expand Down Expand Up @@ -146,7 +178,7 @@ object Parser {
for {
obj <- objEither
key <- convertKeyNode(tup.getKeyNode)
value <- yamlToJson(tup.getValueNode)
value <- yamlToJson(tup.getValueNode, loaderOptions)
} yield obj.add(key, value)
}
.map(Json.fromJsonObject)
Expand All @@ -155,7 +187,7 @@ object Parser {
.foldLeft(Either.right[ParsingFailure, List[Json]](List.empty[Json])) { (arrEither, node) =>
for {
arr <- arrEither
value <- yamlToJson(node)
value <- yamlToJson(node, loaderOptions)
} yield value :: arr
}
.map(arr => Json.fromValues(arr.reverse))
Expand Down
10 changes: 6 additions & 4 deletions circe-yaml/src/main/scala/io/circe/yaml/parser/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ package object parser extends io.circe.yaml.common.Parser {
Parser.default.decodeAccumulating[A](input)

@deprecated("moved to Parser.CustomTag", since = "0.14.2")
private val loaderOptions = {
private val loaderOptions: LoaderOptions = {
val options = new LoaderOptions()
options.setMaxAliasesForCollections(50)
options.setMaxAliasesForCollections(Parser.defaultMaxAliasesForCollections)
options.setNestingDepthLimit(Parser.defaultNestingDepthLimit)
options.setCodePointLimit(Parser.defaultCodePointLimit)
options
}

Expand All @@ -69,8 +71,8 @@ package object parser extends io.circe.yaml.common.Parser {
}

@deprecated("moved to Parser.CustomTag", since = "0.14.2")
private[this] class FlatteningConstructor extends Parser.FlatteningConstructor
private[this] class FlatteningConstructor extends Parser.FlatteningConstructor(loaderOptions)

@deprecated("moved to Parser.CustomTag", since = "0.14.2")
private[this] def yamlToJson(node: Node): Either[ParsingFailure, Json] = Parser.yamlToJson(node)
private[this] def yamlToJson(node: Node): Either[ParsingFailure, Json] = Parser.yamlToJson(node, loaderOptions)
}
57 changes: 55 additions & 2 deletions circe-yaml/src/test/scala/io/circe/yaml/ParserTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class ParserTests extends AnyFlatSpec with Matchers with EitherValues {
.parse(
""
)
.right
.value == Json.False
)
}
Expand All @@ -100,7 +99,6 @@ class ParserTests extends AnyFlatSpec with Matchers with EitherValues {
.parse(
" "
)
.right
.value == Json.False
)
}
Expand Down Expand Up @@ -233,4 +231,59 @@ class ParserTests extends AnyFlatSpec with Matchers with EitherValues {
assertResult(1)(result.size)
assert(result.head.isLeft)
}

it should "parse when within depth limits" in {
assert(
Parser(nestingDepthLimit = 3)
.parse(
"""
| foo:
| bar:
| baz
|""".stripMargin
)
.isRight
)
}

it should "fail to parse when depth limit is exceeded" in {
assert(
Parser(nestingDepthLimit = 1)
.parse(
"""
| foo:
| bar:
| baz
|""".stripMargin
)
.isLeft
)
}

it should "parse when within code point limit" in {
assert(
Parser(codePointLimit = 1 * 1024 * 1024) // 1MB
.parse(
"""
| foo:
| bar:
| baz
|""".stripMargin
)
.isRight
)
}

it should "fail to parse when code point limit is exceeded" in {
assert(
Parser(codePointLimit = 13) // 13B
.parse(
"""
| foo:
| bar
|""".stripMargin
)
.isLeft
)
}
}

0 comments on commit e273a7e

Please sign in to comment.